Commit 876a36a4 authored by raojy's avatar raojy
Browse files

first

parent eda2afb8
[project]
name = "SenseNova-SI"
version = "0.1.0"
description = "Scaling Spatial Intelligence with Multimodal Foundation Models"
readme = "README.md"
requires-python = ">=3.11"
keywords = ["computer vision", "multimodal", "spatial intelligence", "MLLM"]
dependencies = [
"transformers>=4.57.0",
"Pillow",
"numpy",
"setuptools",
"einops>=0.8.1",
"timm>=1.0.22",
"accelerate>=1.11.0",
"opencv-python>=4.11.0.86",
]
[dependency-groups]
flash-attn = ["flash-attn==2.7.4.post1"]
dev = ["ruff==0.14.4"]
[project.optional-dependencies]
cu118 = ["torch>=2.4.0", "torchvision"]
cu121 = ["torch>=2.4.0", "torchvision"]
cu124 = ["torch>=2.4.0", "torchvision"]
cu126 = ["torch>=2.4.0", "torchvision"]
cu128 = ["torch>=2.4.0", "torchvision"]
cu129 = ["torch>=2.4.0", "torchvision"]
[tool.uv]
default-groups = ["flash-attn"]
no-build-isolation-package = ['flash-attn', 'setuptools']
conflicts = [
[
{ extra = "cu118" },
{ extra = "cu121" },
{ extra = "cu124" },
{ extra = "cu126" },
{ extra = "cu128" },
{ extra = "cu129" },
],
]
index = [
{ name = "pytorch-cu118", url = "https://download.pytorch.org/whl/cu118", explicit = true },
{ name = "pytorch-cu121", url = "https://download.pytorch.org/whl/cu121", explicit = true },
{ name = "pytorch-cu124", url = "https://download.pytorch.org/whl/cu124", explicit = true },
{ name = "pytorch-cu126", url = "https://download.pytorch.org/whl/cu126", explicit = true },
{ name = "pytorch-cu128", url = "https://download.pytorch.org/whl/cu128", explicit = true },
{ name = "pytorch-cu129", url = "https://download.pytorch.org/whl/cu129", explicit = true },
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cu118", extra = "cu118" },
{ index = "pytorch-cu121", extra = "cu121" },
{ index = "pytorch-cu124", extra = "cu124" },
{ index = "pytorch-cu126", extra = "cu126" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu129", extra = "cu129" },
]
torchvision = [
{ index = "pytorch-cu118", extra = "cu118" },
{ index = "pytorch-cu121", extra = "cu121" },
{ index = "pytorch-cu124", extra = "cu124" },
{ index = "pytorch-cu126", extra = "cu126" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu129", extra = "cu129" },
]
from .bagel import SenseNovaSIBagelModel
from .internvl import SenseNovaSIInternVLModel
from .qwen import SenseNovaSIQwenModel
def get_default_model_type(model_path):
if "qwen" in model_path.lower():
return "qwen"
elif "internvl" in model_path.lower():
return "internvl"
elif "bagel" in model_path.lower():
return "bagel"
else:
raise ValueError(f"Unknown model type for {model_path}")
def get_model(model_path, model_type="auto"):
if model_type == "auto":
model_type = get_default_model_type(model_path)
if model_type == "qwen":
return SenseNovaSIQwenModel(model_path)
elif model_type == "internvl":
return SenseNovaSIInternVLModel(model_path)
elif model_type == "bagel":
return SenseNovaSIBagelModel(model_path)
else:
raise ValueError(f"Unknown model type: {model_type}")
__all__ = [
"get_default_model_type",
"get_model",
"SenseNovaSIInternVLModel",
"SenseNovaSIQwenModel",
"SenseNovaSIBagelModel",
]
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from accelerate import (
infer_auto_device_map,
init_empty_weights,
load_checkpoint_and_dispatch,
)
from huggingface_hub import snapshot_download
from PIL import Image
from .bagel_utils.data.transforms import ImageTransform
from .bagel_utils.inferencer import InterleaveInferencer
from .bagel_utils.modeling.autoencoder import load_ae
from .bagel_utils.modeling.bagel import (
Bagel,
BagelConfig,
Qwen2Config,
Qwen2ForCausalLM,
SiglipVisionConfig,
SiglipVisionModel,
)
from .bagel_utils.modeling.qwen2 import Qwen2Tokenizer
from .model import Model
from .utils import add_special_tokens
BASE_PARAMS: Dict[str, Dict[str, Any]] = {
"generate": dict(
cfg_text_scale=4.0,
cfg_img_scale=1.0,
cfg_interval=[0.4, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=1.0,
cfg_renorm_type="global",
),
"think_generate": dict(
max_think_token_n=1000,
do_sample=False,
cfg_text_scale=4.0,
cfg_img_scale=1.0,
cfg_interval=[0.4, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=1.0,
cfg_renorm_type="global",
think=True,
),
"edit": dict(
cfg_text_scale=4.0,
cfg_img_scale=2.0,
cfg_interval=[0.0, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=0.0,
cfg_renorm_type="text_channel",
),
"think_edit": dict(
max_think_token_n=1000,
do_sample=False,
cfg_text_scale=4.0,
cfg_img_scale=2.0,
cfg_interval=[0.0, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=0.0,
cfg_renorm_type="text_channel",
think=True,
),
"understanding": dict(
max_think_token_n=1000,
do_sample=False,
understanding_output=True,
),
"think_understanding": dict(
max_think_token_n=1000,
do_sample=False,
understanding_output=True,
think=True,
),
}
class SenseNovaSIBagelModel(Model):
def __init__(
self,
model_path="sensenova/SenseNova-SI-1.1-BAGEL-7B-MoT",
generation_config: dict[str, Any] | str | os.PathLike | None = None,
mode="understanding",
out_img_dir="./output_images/test_bagel/",
dtype: str = "bf16",
):
super().__init__(generation_config)
# 1. Parse params
self.precision = dtype
if os.path.exists(model_path):
cache_path = model_path
else:
cache_path = snapshot_download(repo_id=model_path)
self.model_path = cache_path
self.checkpoint_path = os.path.join(self.model_path, "model.safetensors")
# Bagel mode
env_mode = os.getenv("BAGEL_MODE")
mode = env_mode.strip() if env_mode and env_mode.strip() else mode
if mode not in BASE_PARAMS:
raise ValueError(
f"Invalid mode '{mode}'. "
f"Bagel Supported modes: {list(BASE_PARAMS.keys())}"
)
self.mode = mode
env_out_img_dir = os.getenv("BAGEL_OUT_IMG_DIR")
self.out_img_dir = (
env_out_img_dir.strip()
if env_out_img_dir and env_out_img_dir.strip()
else out_img_dir
)
msg = (
f"[Bagel] mode = '{self.mode}' "
f"(can be overridden with env var BAGEL_MODE); "
f"out_img_dir = '{self.out_img_dir}' "
f"(can be overridden with env var BAGEL_OUT_IMG_DIR)"
)
print(msg)
# 2. Build model
model, vae_model, tokenizer, new_token_ids, vit_transform, vae_transform = (
self._build_model()
)
# 3. Load Checkpoint
model = self._load_model_weights(model)
# 4. Build inferencer
self.tokenizer = tokenizer
self.new_token_ids = new_token_ids
self.vit_transform = vit_transform
self.inferencer = InterleaveInferencer(
model=model,
vae_model=vae_model,
tokenizer=tokenizer,
vae_transform=vae_transform,
vit_transform=vit_transform,
new_token_ids=new_token_ids,
)
torch.cuda.empty_cache()
def _build_model(self):
# build llm config
llm_config = Qwen2Config.from_json_file(
os.path.join(self.model_path, "llm_config.json")
)
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"
# build vit config
vit_config = SiglipVisionConfig.from_json_file(
os.path.join(self.model_path, "vit_config.json")
)
vit_config.rope = False
vit_config.num_hidden_layers -= 1
vit_transform = ImageTransform(980, 224, 14)
vae_transform = ImageTransform(1024, 512, 16)
# build vae config
vae_model, vae_config = load_ae(
local_path=os.path.join(self.model_path, "ae.safetensors")
)
# build tokenizer
tokenizer = Qwen2Tokenizer.from_pretrained(self.model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
# build model
model_config = BagelConfig(
visual_gen=True,
visual_und=True,
llm_config=llm_config,
vit_config=vit_config,
vae_config=vae_config,
latent_patch_size=2,
max_latent_size=64,
vit_max_num_patch_per_side=70,
connector_act="gelu_pytorch_tanh",
)
with init_empty_weights():
language_model = Qwen2ForCausalLM(llm_config)
vit_model = SiglipVisionModel(vit_config)
model = Bagel(language_model, vit_model, model_config)
model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)
return model, vae_model, tokenizer, new_token_ids, vit_transform, vae_transform
def _load_model_weights(self, model):
device_map = infer_auto_device_map(
model, no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"]
)
same_device_modules = [
"language_model.model.embed_tokens",
"time_embedder",
"latent_pos_embed",
"vae2llm",
"llm2vae",
"connector",
"vit_pos_embed",
]
if torch.cuda.device_count() == 1:
first_device = device_map.get(same_device_modules[0], "cuda:0")
for k in same_device_modules:
if k in device_map:
device_map[k] = first_device
else:
device_map[k] = "cuda:0"
else:
first_device = device_map.get(same_device_modules[0])
for k in same_device_modules:
if k in device_map:
device_map[k] = first_device
if self.precision == "bf16":
model = load_checkpoint_and_dispatch(
model,
checkpoint=self.checkpoint_path,
device_map=device_map,
offload_buffers=True,
offload_folder="offload",
dtype=torch.bfloat16,
force_hooks=True,
).eval()
elif self.precision == "nf4":
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
model = load_and_quantize_model(
model,
weights_location=self.checkpoint_path,
bnb_quantization_config=BnbQuantizationConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
),
device_map=device_map,
offload_folder="offload",
).eval()
elif self.precision == "int8":
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
model = load_and_quantize_model(
model,
weights_location=self.checkpoint_path,
bnb_quantization_config=BnbQuantizationConfig(
load_in_8bit=True, torch_dtype=torch.float32
),
device_map=device_map,
offload_folder="offload",
).eval()
else:
raise NotImplementedError(f"Unsupported precision: {self.precision}")
return model
def _save_output_image(
self,
image: Image.Image,
mode: str,
img_path: Optional[str],
) -> str:
if image is None:
raise ValueError(
f"[OutputError] Mode={mode} expected an image output, but got None."
)
root = Path(self.out_img_dir)
images_root = root / (f"images")
images_root.mkdir(parents=True, exist_ok=True)
if mode in ["edit", "think_edit"]:
if img_path:
src = Path(img_path)
parent_name = src.parent.name or "default"
out_dir = images_root / parent_name
out_dir.mkdir(parents=True, exist_ok=True)
filename = src.name
else:
out_dir = images_root / "edit"
out_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
base = "sample"
filename = f"{base}_edit_{ts}_{uuid.uuid4().hex[:8]}.jpg"
out_path = out_dir / filename
elif mode in ["generate", "think_generate"]:
out_dir = images_root
out_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
base = "sample"
filename = f"{base}_{ts}_{uuid.uuid4().hex[:8]}.jpg"
out_path = out_dir / filename
else:
raise ValueError(f"[OutputError] Unexpected mode for image saving: {mode}")
image.save(out_path)
return str(out_path)
def generate(self, question: str, images: list[str] | None = None, **kwargs):
mode = self.mode
images = images or []
# Auto-prepend <image> placeholders if the question doesn't contain them
existing_count = question.count("<image>")
if images and existing_count == 0:
question = "".join(["<image>\n" for _ in images]) + question
text_parts = question.split("<image>")
if len(text_parts) != len(images) + 1:
raise ValueError(f"Text iamge tokens and number of images not match! ")
input_lists = []
input_img_paths = []
for i, part in enumerate(text_parts):
text = part.strip()
if text:
input_lists.append(text)
if i < len(images):
img_path = images[i]
try:
image = Image.open(img_path)
input_lists.append(image)
input_img_paths.append(img_path)
except Exception as e:
raise RuntimeError(f"Can not load image {img_path}: {e}") from e
params = dict(BASE_PARAMS[mode])
understanding_output_flag = params.pop("understanding_output", False)
think_flag = params.pop("think", False)
res = self.inferencer.interleave_inference(
input_lists=input_lists,
think=think_flag,
understanding_output=understanding_output_flag,
**params,
)
ret = {"image": [], "text": []}
for i in res:
if isinstance(i, Image.Image):
ret["image"].append(i)
elif isinstance(i, str):
ret["text"].append(i)
img_cnt, txt_cnt = len(ret["image"]), len(ret["text"])
if img_cnt + txt_cnt != 1:
print(
f"[Warning] You are using {mode} mode, so the output has {img_cnt} images and {txt_cnt} texts"
)
if txt_cnt > 0:
print(f"[Warning] The text output is: {ret['text'][0]}")
ret["image"] = ret["image"][0] if img_cnt else None
ret["text"] = ret["text"][0] if txt_cnt else None
if mode in ["edit", "think_edit", "generate", "think_generate"]:
if ret["image"] is not None:
if len(input_img_paths) == 1:
ref_img_path = input_img_paths[0]
else:
ref_img_path = None
img_path_out = self._save_output_image(
image=ret["image"],
mode=mode,
img_path=ref_img_path,
)
ret["image"] = img_path_out
res = img_path_out
else:
res = None
else:
res = ret["text"]
return res
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import math
import random
import torch
from PIL import Image
# from torch.nn.attention.flex_attention import or_masks, and_masks
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def full_and_noise_mask(b, h, q_idx, kv_idx):
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (
full_and_noise_seq_id[q_idx] >= 0
)
def remove_noise_mask(b, h, q_idx, kv_idx):
return ~(
(noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])
)
def sample_mask(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
full_and_noise_tmp = []
noise_tmp = []
for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
value = i if model in ["full", "noise"] else -1
full_and_noise_tmp.extend([value] * length)
value_noise = i if model == "noise" else -1
noise_tmp.extend([value_noise] * length)
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
noise_seq_id = torch.Tensor(noise_tmp).to(device)
document_id = torch.cat(
[torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]
).to(device)
return and_masks(
or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask
)
def patchify(image, patch_size):
p = patch_size
c, h, w = image.shape
assert h % p == 0 and w % p == 0
image = image.reshape(c, h // p, p, w // p, p)
image = torch.einsum("chpwq->hwpqc", image)
image = image.reshape(-1, p**2 * c)
return image
def get_flattened_position_ids_extrapolate(
img_h, img_w, patch_size, max_num_patches_per_side
):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
coords_h = torch.arange(0, num_patches_h)
coords_w = torch.arange(0, num_patches_w)
pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
return pos_ids
def get_flattened_position_ids_interpolate(
img_h, img_w, patch_size, max_num_patches_per_side
):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
boundaries = torch.arange(
1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side
)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (
bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w
).flatten()
return pos_ids
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
"""
nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
a sample, where each sample contains multiple splits with different attn modes.
nested_attn_modes: whether to use full attn in each split.
"""
sample_len = sum(split_lens)
attention_mask = torch.zeros(
(sample_len, sample_len), dtype=torch.bool, device=device
)
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
assert attn_mode in ["causal", "full", "noise"]
if attn_mode == "causal":
attention_mask[csum : csum + s, csum : csum + s] = torch.ones(
(s, s), device=device
).tril()
attention_mask[csum : csum + s, :csum] = 1
else:
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
attention_mask[csum : csum + s, :csum] = 1
csum += s
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
if attn_mode == "noise":
attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
csum += s
attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
~attention_mask, float("-inf")
)
return attention_mask
def split_integer_exp_decay(S, ng_sample_decay=1.0):
if ng_sample_decay == 1.0:
N = random.randint(1, S)
else:
base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
N = random.choices(list(range(1, S + 1)), p, k=1)[0]
cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
result = [cumsum[i + 1] - cumsum[i] for i in range(len(cumsum) - 1)]
return result, cumsum
def pil_img2rgb(image):
if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
image = image.convert("RGBA")
white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
white.paste(image, mask=image.split()[3])
image = white
else:
image = image.convert("RGB")
return image
def add_special_tokens(tokenizer):
all_special_tokens = []
for k, v in tokenizer.special_tokens_map.items():
if isinstance(v, str):
all_special_tokens.append(v)
elif isinstance(v, list):
all_special_tokens += v
new_tokens = []
if "<|im_start|>" not in all_special_tokens:
new_tokens.append("<|im_start|>")
if "<|im_end|>" not in all_special_tokens:
new_tokens.append("<|im_end|>")
if "<|vision_start|>" not in all_special_tokens:
new_tokens.append("<|vision_start|>")
if "<|vision_end|>" not in all_special_tokens:
new_tokens.append("<|vision_end|>")
num_new_tokens = tokenizer.add_tokens(new_tokens)
bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
start_of_image = tokenizer.convert_tokens_to_ids("<|vision_start|>")
end_of_image = tokenizer.convert_tokens_to_ids("<|vision_end|>")
new_token_ids = dict(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
start_of_image=start_of_image,
end_of_image=end_of_image,
)
return tokenizer, new_token_ids, num_new_tokens
def len2weight(x, loss_reduction="square"):
if x == 0:
return x
if loss_reduction == "token":
return 1
if loss_reduction == "sample":
return 1 / x
if loss_reduction == "square":
return 1 / (x**0.5)
raise NotImplementedError(loss_reduction)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import random
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
"""Resize the input image so that its longest side and shortest side are within a specified range,
ensuring that both sides are divisible by a specified stride.
Args:
max_size (int): Maximum size for the longest edge of the image.
min_size (int): Minimum size for the shortest edge of the image.
stride (int): Value by which the height and width of the image must be divisible.
max_pixels (int): Maximum pixels for the full image.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
antialias (bool, optional): Whether to apply antialiasing (default is True).
"""
def __init__(
self,
max_size: int,
min_size: int,
stride: int,
max_pixels: int,
interpolation=InterpolationMode.BICUBIC,
antialias=True,
):
super().__init__()
self.max_size = max_size
self.min_size = min_size
self.stride = stride
self.max_pixels = max_pixels
self.interpolation = interpolation
self.antialias = antialias
def _make_divisible(self, value, stride):
"""Ensure the value is divisible by the stride."""
return max(stride, int(round(value / stride) * stride))
def _apply_scale(self, width, height, scale):
new_width = round(width * scale)
new_height = round(height * scale)
new_width = self._make_divisible(new_width, self.stride)
new_height = self._make_divisible(new_height, self.stride)
return new_width, new_height
def forward(self, img, img_num=1):
"""
Args:
img (PIL Image): Image to be resized.
img_num (int): Number of images, used to change max_tokens.
Returns:
PIL Image or Tensor: Rescaled image with divisible dimensions.
"""
if isinstance(img, torch.Tensor):
height, width = img.shape[-2:]
else:
width, height = img.size
scale = min(self.max_size / max(width, height), 1.0)
scale = max(scale, self.min_size / min(width, height))
new_width, new_height = self._apply_scale(width, height, scale)
# Ensure the number of pixels does not exceed max_pixels
if new_width * new_height > self.max_pixels / img_num:
scale = self.max_pixels / img_num / (new_width * new_height)
new_width, new_height = self._apply_scale(new_width, new_height, scale)
# Ensure longest edge does not exceed max_size
if max(new_width, new_height) > self.max_size:
scale = self.max_size / max(new_width, new_height)
new_width, new_height = self._apply_scale(new_width, new_height, scale)
return F.resize(
img, (new_height, new_width), self.interpolation, antialias=self.antialias
)
class ImageTransform:
def __init__(
self,
max_image_size,
min_image_size,
image_stride,
max_pixels=14 * 14 * 9 * 1024,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
self.stride = image_stride
self.resize_transform = MaxLongEdgeMinShortEdgeResize(
max_size=max_image_size,
min_size=min_image_size,
stride=image_stride,
max_pixels=max_pixels,
)
self.to_tensor_transform = transforms.ToTensor()
self.normalize_transform = transforms.Normalize(
mean=image_mean, std=image_std, inplace=True
)
def __call__(self, img, img_num=1):
img = self.resize_transform(img, img_num=img_num)
img = self.to_tensor_transform(img)
img = self.normalize_transform(img)
return img
def decolorization(image):
gray_image = image.convert("L")
return (
Image.merge(image.mode, [gray_image] * 3)
if image.mode in ("RGB", "L")
else gray_image
)
def downscale(image, scale_factor):
new_width = int(round(image.width * scale_factor))
new_height = int(round(image.height * scale_factor))
new_width = max(1, new_width)
new_height = max(1, new_height)
return image.resize((new_width, new_height), resample=Image.BICUBIC)
def crop(image, crop_factors):
target_h, target_w = crop_factors
img_w, img_h = image.size
if target_h > img_h or target_w > img_w:
raise ValueError("Crop size exceeds image dimensions")
x = random.randint(0, img_w - target_w)
y = random.randint(0, img_h - target_h)
return image.crop((x, y, x + target_w, y + target_h)), [
[x, y],
[x + target_w, y + target_h],
]
def motion_blur_opencv(image, kernel_size=15, angle=0):
# 线性核
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
# 旋转核
center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
M = cv2.getRotationMatrix2D(center, angle, 1)
rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
# 归一化核
rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
img = np.array(image)
if img.ndim == 2:
blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
else:
# 对于彩色图像,各通道独立卷积
blurred = np.zeros_like(img)
for c in range(img.shape[2]):
blurred[..., c] = cv2.filter2D(
img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT
)
return Image.fromarray(blurred.astype(np.uint8))
def shuffle_patch(image, num_splits, gap_size=2):
"""将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
h_splits, w_splits = num_splits
img_w, img_h = image.size
base_patch_h = img_h // h_splits
patch_heights = [base_patch_h] * (h_splits - 1)
patch_heights.append(img_h - sum(patch_heights))
base_patch_w = img_w // w_splits
patch_widths = [base_patch_w] * (w_splits - 1)
patch_widths.append(img_w - sum(patch_widths))
patches = []
current_y = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
patch_w = patch_widths[j]
patch = image.crop(
(current_x, current_y, current_x + patch_w, current_y + patch_h)
)
patches.append(patch)
current_x += patch_w
current_y += patch_h
random.shuffle(patches)
total_width = sum(patch_widths) + (w_splits - 1) * gap_size
total_height = sum(patch_heights) + (h_splits - 1) * gap_size
new_image = Image.new(
image.mode, (total_width, total_height), color=(255, 255, 255)
)
current_y = 0 # 当前行的起始 Y 坐标
patch_idx = 0 # 当前处理的块索引
for i in range(h_splits):
current_x = 0 # 当前列的起始 X 坐标
patch_h = patch_heights[i] # 当前行块的高度
for j in range(w_splits):
# 取出打乱后的块
patch = patches[patch_idx]
patch_w = patch_widths[j] # 当前列块的宽度
# 粘贴块(左上角坐标为 (current_x, current_y))
new_image.paste(patch, (current_x, current_y))
# 更新 X 坐标(下一个块的起始位置 = 当前块宽度 + 间隙)
current_x += patch_w + gap_size
patch_idx += 1
# 更新 Y 坐标(下一行的起始位置 = 当前行高度 + 间隙)
current_y += patch_h + gap_size
return new_image
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
"""
图像分割后随机空白部分patch,用于inpainting任务
参数:
image: PIL.Image 输入图像(RGB模式)
h_splits: int 行分割数(垂直方向分割块数)
w_splits: int 列分割数(水平方向分割块数)
blank_ratio: float 空白patch的比例(0~1)
blank_color: tuple 空白区域的颜色(RGB,如白色(255,255,255))
返回:
PIL.Image 处理后拼接的图像
"""
h_splits, w_splits = num_splits
img_w, img_h = image.size
base_patch_h = img_h // h_splits
patch_heights = [base_patch_h] * (h_splits - 1)
patch_heights.append(img_h - sum(patch_heights))
base_patch_w = img_w // w_splits
patch_widths = [base_patch_w] * (w_splits - 1)
patch_widths.append(img_w - sum(patch_widths))
patches = []
current_y = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
patch_w = patch_widths[j]
patch = image.crop(
(current_x, current_y, current_x + patch_w, current_y + patch_h)
)
patches.append(patch)
current_x += patch_w
current_y += patch_h
total_patches = h_splits * w_splits
num_blank = int(total_patches * blank_ratio)
num_blank = max(0, min(num_blank, total_patches))
blank_indices = random.sample(range(total_patches), num_blank)
processed_patches = []
for idx, patch in enumerate(patches):
if idx in blank_indices:
blank_patch = Image.new("RGB", patch.size, color=blank_color)
processed_patches.append(blank_patch)
else:
processed_patches.append(patch)
# 创建结果图像(尺寸与原图一致)
result_image = Image.new("RGB", (img_w, img_h))
current_y = 0
patch_idx = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
# 取出处理后的patch
patch = processed_patches[patch_idx]
patch_w = patch_widths[j]
# 粘贴到原位置
result_image.paste(patch, (current_x, current_y))
current_x += patch_w
patch_idx += 1
current_y += patch_h
return result_image
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image
from .data.data_utils import pil_img2rgb
from .modeling.bagel.qwen2_navit import NaiveCache
VLM_THINK_SYSTEM_PROMPT = """You should first think about the reasoning process in the mind and then provide the user with the answer.
The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here"""
GEN_THINK_SYSTEM_PROMPT = """You should first think about the planning process in the mind and then generate the image.
The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here"""
class InterleaveInferencer:
def __init__(
self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids
):
self.model = model
self.vae_model = vae_model
self.tokenizer = tokenizer
self.vae_transform = vae_transform
self.vit_transform = vit_transform
self.new_token_ids = new_token_ids
def init_gen_context(self):
gen_context = {
"kv_lens": [0],
"ropes": [0],
"past_key_values": NaiveCache(
self.model.config.llm_config.num_hidden_layers
),
}
return gen_context
@torch.no_grad()
def update_context_text(self, text, gen_context):
# used for interleave data, currently only support 1 data inference,
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input, kv_lens, ropes = self.model.prepare_prompts(
curr_kvlens=kv_lens,
curr_rope=ropes,
prompts=[text],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
past_key_values = self.model.forward_cache_update_text(
past_key_values, **generation_input
)
gen_context["kv_lens"] = kv_lens
gen_context["ropes"] = ropes
gen_context["past_key_values"] = past_key_values
return gen_context
@torch.no_grad()
def update_context_image(self, image, gen_context, vae=True, vit=True):
# used for interleave data, currently only support 1 data inference,
assert vae or vit
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
if vae:
## update vae
generation_input, kv_lens, ropes = self.model.prepare_vae_images(
curr_kvlens=kv_lens,
curr_rope=ropes,
images=[image],
transforms=self.vae_transform,
new_token_ids=self.new_token_ids,
)
past_key_values = self.model.forward_cache_update_vae(
self.vae_model, past_key_values, **generation_input
)
if vit:
## update vit
generation_input, kv_lens, ropes = self.model.prepare_vit_images(
curr_kvlens=kv_lens,
curr_rope=ropes,
images=[image],
transforms=self.vit_transform,
new_token_ids=self.new_token_ids,
)
past_key_values = self.model.forward_cache_update_vit(
past_key_values, **generation_input
)
gen_context["kv_lens"] = kv_lens
gen_context["ropes"] = ropes
gen_context["past_key_values"] = past_key_values
return gen_context
@torch.no_grad()
def gen_image(
self,
image_shape,
gen_context,
cfg_text_scale=4.0,
cfg_img_scale=1.5,
cfg_text_precontext=None,
cfg_img_precontext=None,
cfg_interval=(0.4, 1.0),
cfg_renorm_min=0.0,
cfg_renorm_type="global",
num_timesteps=50,
timestep_shift=3.0,
enable_taylorseer=False,
):
# print(cfg_renorm_type)
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input = self.model.prepare_vae_latent(
curr_kvlens=kv_lens,
curr_rope=ropes,
image_sizes=[image_shape],
new_token_ids=self.new_token_ids,
)
# text cfg
cfg_text_past_key_values = cfg_text_precontext["past_key_values"]
kv_lens_cfg = cfg_text_precontext["kv_lens"]
ropes_cfg = cfg_text_precontext["ropes"]
generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
curr_kvlens=kv_lens_cfg,
curr_rope=ropes_cfg,
image_sizes=[image_shape],
)
# img cfg
cfg_img_past_key_values = cfg_img_precontext["past_key_values"]
kv_lens_cfg = cfg_img_precontext["kv_lens"]
ropes_cfg = cfg_img_precontext["ropes"]
generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
curr_kvlens=kv_lens_cfg,
curr_rope=ropes_cfg,
image_sizes=[image_shape],
)
unpacked_latent = self.model.generate_image(
past_key_values=past_key_values,
cfg_text_past_key_values=cfg_text_past_key_values,
cfg_img_past_key_values=cfg_img_past_key_values,
num_timesteps=num_timesteps,
cfg_text_scale=cfg_text_scale,
cfg_img_scale=cfg_img_scale,
cfg_interval=cfg_interval,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
timestep_shift=timestep_shift,
**generation_input,
cfg_text_packed_position_ids=generation_input_cfg_text[
"cfg_packed_position_ids"
],
cfg_text_packed_query_indexes=generation_input_cfg_text[
"cfg_packed_query_indexes"
],
cfg_text_key_values_lens=generation_input_cfg_text["cfg_key_values_lens"],
cfg_text_packed_key_value_indexes=generation_input_cfg_text[
"cfg_packed_key_value_indexes"
],
cfg_img_packed_position_ids=generation_input_cfg_img[
"cfg_packed_position_ids"
],
cfg_img_packed_query_indexes=generation_input_cfg_img[
"cfg_packed_query_indexes"
],
cfg_img_key_values_lens=generation_input_cfg_img["cfg_key_values_lens"],
cfg_img_packed_key_value_indexes=generation_input_cfg_img[
"cfg_packed_key_value_indexes"
],
enable_taylorseer=enable_taylorseer,
)
image = self.decode_image(unpacked_latent[0], image_shape)
return image
def decode_image(self, latent, image_shape):
H, W = image_shape
h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
latent = latent.reshape(
1,
h,
w,
self.model.latent_patch_size,
self.model.latent_patch_size,
self.model.latent_channel,
)
latent = torch.einsum("nhwpqc->nchpwq", latent)
latent = latent.reshape(
1,
self.model.latent_channel,
h * self.model.latent_patch_size,
w * self.model.latent_patch_size,
)
image = self.vae_model.decode(latent)
image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
return image
@torch.no_grad()
def gen_text(
self,
gen_context,
max_length: int = 500,
do_sample: bool = True,
temperature: float = 1.0,
):
gen_context = deepcopy(gen_context)
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input = self.model.prepare_start_tokens(
kv_lens, ropes, self.new_token_ids
)
unpacked_latent = self.model.generate_text(
past_key_values=past_key_values,
max_length=max_length,
do_sample=do_sample,
temperature=temperature,
end_token_id=self.new_token_ids["eos_token_id"],
**generation_input,
)
output = self.tokenizer.decode(unpacked_latent[:, 0])
output = output.split("<|im_end|>")[0].split("<|im_start|>")[1]
return output
@torch.no_grad()
def interleave_inference(
self,
input_lists: List[Union[str, Image.Image]],
think=False,
understanding_output=False,
max_think_token_n=1000,
do_sample=False,
text_temperature=0.3,
cfg_text_scale=3.0,
cfg_img_scale=1.5,
cfg_interval=[0.4, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=0.0,
cfg_renorm_type="global",
image_shapes=(1024, 1024),
enable_taylorseer=False,
) -> List[Union[str, Image.Image]]:
output_list = []
gen_context = self.init_gen_context()
cfg_text_context = deepcopy(gen_context)
cfg_img_context = deepcopy(gen_context)
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
if think:
if understanding_output:
system_prompt = VLM_THINK_SYSTEM_PROMPT
else:
system_prompt = GEN_THINK_SYSTEM_PROMPT
gen_context = self.update_context_text(system_prompt, gen_context)
cfg_img_context = self.update_context_text(
system_prompt, cfg_img_context
)
for input_term in input_lists:
if isinstance(input_term, str):
cfg_text_context = deepcopy(gen_context)
gen_context = self.update_context_text(input_term, gen_context)
cfg_img_context = self.update_context_text(
input_term, cfg_img_context
)
elif isinstance(input_term, Image.Image):
input_term = self.vae_transform.resize_transform(
pil_img2rgb(input_term)
)
gen_context = self.update_context_image(
input_term, gen_context, vae=not understanding_output
)
image_shapes = input_term.size[::-1]
cfg_text_context = deepcopy(gen_context)
else:
raise ValueError(f"Unsupported input type: {type(input_term)}")
if understanding_output:
gen_text = self.gen_text(
gen_context,
do_sample=do_sample,
temperature=text_temperature,
max_length=max_think_token_n,
)
output_list.append(gen_text)
else:
if think:
gen_text = self.gen_text(
gen_context,
do_sample=do_sample,
temperature=text_temperature,
max_length=max_think_token_n,
)
gen_context = self.update_context_text(gen_text, gen_context)
output_list.append(gen_text)
img = self.gen_image(
image_shapes,
gen_context,
cfg_text_precontext=cfg_text_context,
cfg_img_precontext=cfg_img_context,
cfg_text_scale=cfg_text_scale,
cfg_img_scale=cfg_img_scale,
cfg_interval=cfg_interval,
timestep_shift=timestep_shift,
num_timesteps=num_timesteps,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
enable_taylorseer=enable_taylorseer,
)
output_list.append(img)
return output_list
def __call__(
self, image: Optional[Image.Image] = None, text: Optional[str] = None, **kargs
) -> Dict[str, Any]:
output_dict = {"image": None, "text": None}
if image is None and text is None:
print("Please provide at least one input: either an image or text.")
return output_dict
input_list = []
if image is not None:
input_list.append(image)
if text is not None:
input_list.append(text)
output_list = self.interleave_inference(input_list, **kargs)
for i in output_list:
if isinstance(i, Image.Image):
output_dict["image"] = i
elif isinstance(i, str):
output_dict["text"] = i
return output_dict
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from . import autoencoder, bagel, qwen2, siglip
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from .bagel import Bagel, BagelConfig
from .qwen2_navit import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
__all__ = [
"BagelConfig",
"Bagel",
"Qwen2Config",
"Qwen2Model",
"Qwen2ForCausalLM",
"SiglipVisionConfig",
"SiglipVisionModel",
]
# Copyright (c) 2022 Facebook, Inc. and its affiliates.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: CC BY-NC 4.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under CC BY-NC 4.0, with the full license text
# available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
#
# This modified file is released under the same license.
import math
import numpy as np
import torch
from torch import nn
from transformers.activations import ACT2FN
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# --------------------------------------------------------
# TimestepEmbedder
# Reference:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class MLPconnector(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
super().__init__()
self.activation_fn = ACT2FN[hidden_act]
self.fc1 = nn.Linear(in_dim, out_dim)
self.fc2 = nn.Linear(out_dim, out_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class PositionEmbedding(nn.Module):
def __init__(self, max_num_patch_per_side, hidden_size):
super().__init__()
self.max_num_patch_per_side = max_num_patch_per_side
self.hidden_size = hidden_size
self.pos_embed = nn.Parameter(
torch.zeros(max_num_patch_per_side**2, hidden_size), requires_grad=False
)
self._init_weights()
def _init_weights(self):
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size, self.max_num_patch_per_side
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
def forward(self, position_ids):
return self.pos_embed[position_ids]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment