Commit 876a36a4 authored by raojy's avatar raojy
Browse files

first

parent eda2afb8
[project]
name = "SenseNova-SI"
version = "0.1.0"
description = "Scaling Spatial Intelligence with Multimodal Foundation Models"
readme = "README.md"
requires-python = ">=3.11"
keywords = ["computer vision", "multimodal", "spatial intelligence", "MLLM"]
dependencies = [
"transformers>=4.57.0",
"Pillow",
"numpy",
"setuptools",
"einops>=0.8.1",
"timm>=1.0.22",
"accelerate>=1.11.0",
"opencv-python>=4.11.0.86",
]
[dependency-groups]
flash-attn = ["flash-attn==2.7.4.post1"]
dev = ["ruff==0.14.4"]
[project.optional-dependencies]
cu118 = ["torch>=2.4.0", "torchvision"]
cu121 = ["torch>=2.4.0", "torchvision"]
cu124 = ["torch>=2.4.0", "torchvision"]
cu126 = ["torch>=2.4.0", "torchvision"]
cu128 = ["torch>=2.4.0", "torchvision"]
cu129 = ["torch>=2.4.0", "torchvision"]
[tool.uv]
default-groups = ["flash-attn"]
no-build-isolation-package = ['flash-attn', 'setuptools']
conflicts = [
[
{ extra = "cu118" },
{ extra = "cu121" },
{ extra = "cu124" },
{ extra = "cu126" },
{ extra = "cu128" },
{ extra = "cu129" },
],
]
index = [
{ name = "pytorch-cu118", url = "https://download.pytorch.org/whl/cu118", explicit = true },
{ name = "pytorch-cu121", url = "https://download.pytorch.org/whl/cu121", explicit = true },
{ name = "pytorch-cu124", url = "https://download.pytorch.org/whl/cu124", explicit = true },
{ name = "pytorch-cu126", url = "https://download.pytorch.org/whl/cu126", explicit = true },
{ name = "pytorch-cu128", url = "https://download.pytorch.org/whl/cu128", explicit = true },
{ name = "pytorch-cu129", url = "https://download.pytorch.org/whl/cu129", explicit = true },
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cu118", extra = "cu118" },
{ index = "pytorch-cu121", extra = "cu121" },
{ index = "pytorch-cu124", extra = "cu124" },
{ index = "pytorch-cu126", extra = "cu126" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu129", extra = "cu129" },
]
torchvision = [
{ index = "pytorch-cu118", extra = "cu118" },
{ index = "pytorch-cu121", extra = "cu121" },
{ index = "pytorch-cu124", extra = "cu124" },
{ index = "pytorch-cu126", extra = "cu126" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu129", extra = "cu129" },
]
from .bagel import SenseNovaSIBagelModel
from .internvl import SenseNovaSIInternVLModel
from .qwen import SenseNovaSIQwenModel
def get_default_model_type(model_path):
if "qwen" in model_path.lower():
return "qwen"
elif "internvl" in model_path.lower():
return "internvl"
elif "bagel" in model_path.lower():
return "bagel"
else:
raise ValueError(f"Unknown model type for {model_path}")
def get_model(model_path, model_type="auto"):
if model_type == "auto":
model_type = get_default_model_type(model_path)
if model_type == "qwen":
return SenseNovaSIQwenModel(model_path)
elif model_type == "internvl":
return SenseNovaSIInternVLModel(model_path)
elif model_type == "bagel":
return SenseNovaSIBagelModel(model_path)
else:
raise ValueError(f"Unknown model type: {model_type}")
__all__ = [
"get_default_model_type",
"get_model",
"SenseNovaSIInternVLModel",
"SenseNovaSIQwenModel",
"SenseNovaSIBagelModel",
]
import os
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from accelerate import (
infer_auto_device_map,
init_empty_weights,
load_checkpoint_and_dispatch,
)
from huggingface_hub import snapshot_download
from PIL import Image
from .bagel_utils.data.transforms import ImageTransform
from .bagel_utils.inferencer import InterleaveInferencer
from .bagel_utils.modeling.autoencoder import load_ae
from .bagel_utils.modeling.bagel import (
Bagel,
BagelConfig,
Qwen2Config,
Qwen2ForCausalLM,
SiglipVisionConfig,
SiglipVisionModel,
)
from .bagel_utils.modeling.qwen2 import Qwen2Tokenizer
from .model import Model
from .utils import add_special_tokens
BASE_PARAMS: Dict[str, Dict[str, Any]] = {
"generate": dict(
cfg_text_scale=4.0,
cfg_img_scale=1.0,
cfg_interval=[0.4, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=1.0,
cfg_renorm_type="global",
),
"think_generate": dict(
max_think_token_n=1000,
do_sample=False,
cfg_text_scale=4.0,
cfg_img_scale=1.0,
cfg_interval=[0.4, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=1.0,
cfg_renorm_type="global",
think=True,
),
"edit": dict(
cfg_text_scale=4.0,
cfg_img_scale=2.0,
cfg_interval=[0.0, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=0.0,
cfg_renorm_type="text_channel",
),
"think_edit": dict(
max_think_token_n=1000,
do_sample=False,
cfg_text_scale=4.0,
cfg_img_scale=2.0,
cfg_interval=[0.0, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=0.0,
cfg_renorm_type="text_channel",
think=True,
),
"understanding": dict(
max_think_token_n=1000,
do_sample=False,
understanding_output=True,
),
"think_understanding": dict(
max_think_token_n=1000,
do_sample=False,
understanding_output=True,
think=True,
),
}
class SenseNovaSIBagelModel(Model):
def __init__(
self,
model_path="sensenova/SenseNova-SI-1.1-BAGEL-7B-MoT",
generation_config: dict[str, Any] | str | os.PathLike | None = None,
mode="understanding",
out_img_dir="./output_images/test_bagel/",
dtype: str = "bf16",
):
super().__init__(generation_config)
# 1. Parse params
self.precision = dtype
if os.path.exists(model_path):
cache_path = model_path
else:
cache_path = snapshot_download(repo_id=model_path)
self.model_path = cache_path
self.checkpoint_path = os.path.join(self.model_path, "model.safetensors")
# Bagel mode
env_mode = os.getenv("BAGEL_MODE")
mode = env_mode.strip() if env_mode and env_mode.strip() else mode
if mode not in BASE_PARAMS:
raise ValueError(
f"Invalid mode '{mode}'. "
f"Bagel Supported modes: {list(BASE_PARAMS.keys())}"
)
self.mode = mode
env_out_img_dir = os.getenv("BAGEL_OUT_IMG_DIR")
self.out_img_dir = (
env_out_img_dir.strip()
if env_out_img_dir and env_out_img_dir.strip()
else out_img_dir
)
msg = (
f"[Bagel] mode = '{self.mode}' "
f"(can be overridden with env var BAGEL_MODE); "
f"out_img_dir = '{self.out_img_dir}' "
f"(can be overridden with env var BAGEL_OUT_IMG_DIR)"
)
print(msg)
# 2. Build model
model, vae_model, tokenizer, new_token_ids, vit_transform, vae_transform = (
self._build_model()
)
# 3. Load Checkpoint
model = self._load_model_weights(model)
# 4. Build inferencer
self.tokenizer = tokenizer
self.new_token_ids = new_token_ids
self.vit_transform = vit_transform
self.inferencer = InterleaveInferencer(
model=model,
vae_model=vae_model,
tokenizer=tokenizer,
vae_transform=vae_transform,
vit_transform=vit_transform,
new_token_ids=new_token_ids,
)
torch.cuda.empty_cache()
def _build_model(self):
# build llm config
llm_config = Qwen2Config.from_json_file(
os.path.join(self.model_path, "llm_config.json")
)
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"
# build vit config
vit_config = SiglipVisionConfig.from_json_file(
os.path.join(self.model_path, "vit_config.json")
)
vit_config.rope = False
vit_config.num_hidden_layers -= 1
vit_transform = ImageTransform(980, 224, 14)
vae_transform = ImageTransform(1024, 512, 16)
# build vae config
vae_model, vae_config = load_ae(
local_path=os.path.join(self.model_path, "ae.safetensors")
)
# build tokenizer
tokenizer = Qwen2Tokenizer.from_pretrained(self.model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
# build model
model_config = BagelConfig(
visual_gen=True,
visual_und=True,
llm_config=llm_config,
vit_config=vit_config,
vae_config=vae_config,
latent_patch_size=2,
max_latent_size=64,
vit_max_num_patch_per_side=70,
connector_act="gelu_pytorch_tanh",
)
with init_empty_weights():
language_model = Qwen2ForCausalLM(llm_config)
vit_model = SiglipVisionModel(vit_config)
model = Bagel(language_model, vit_model, model_config)
model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)
return model, vae_model, tokenizer, new_token_ids, vit_transform, vae_transform
def _load_model_weights(self, model):
device_map = infer_auto_device_map(
model, no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"]
)
same_device_modules = [
"language_model.model.embed_tokens",
"time_embedder",
"latent_pos_embed",
"vae2llm",
"llm2vae",
"connector",
"vit_pos_embed",
]
if torch.cuda.device_count() == 1:
first_device = device_map.get(same_device_modules[0], "cuda:0")
for k in same_device_modules:
if k in device_map:
device_map[k] = first_device
else:
device_map[k] = "cuda:0"
else:
first_device = device_map.get(same_device_modules[0])
for k in same_device_modules:
if k in device_map:
device_map[k] = first_device
if self.precision == "bf16":
model = load_checkpoint_and_dispatch(
model,
checkpoint=self.checkpoint_path,
device_map=device_map,
offload_buffers=True,
offload_folder="offload",
dtype=torch.bfloat16,
force_hooks=True,
).eval()
elif self.precision == "nf4":
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
model = load_and_quantize_model(
model,
weights_location=self.checkpoint_path,
bnb_quantization_config=BnbQuantizationConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
bnb_4bit_quant_type="nf4",
),
device_map=device_map,
offload_folder="offload",
).eval()
elif self.precision == "int8":
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
model = load_and_quantize_model(
model,
weights_location=self.checkpoint_path,
bnb_quantization_config=BnbQuantizationConfig(
load_in_8bit=True, torch_dtype=torch.float32
),
device_map=device_map,
offload_folder="offload",
).eval()
else:
raise NotImplementedError(f"Unsupported precision: {self.precision}")
return model
def _save_output_image(
self,
image: Image.Image,
mode: str,
img_path: Optional[str],
) -> str:
if image is None:
raise ValueError(
f"[OutputError] Mode={mode} expected an image output, but got None."
)
root = Path(self.out_img_dir)
images_root = root / (f"images")
images_root.mkdir(parents=True, exist_ok=True)
if mode in ["edit", "think_edit"]:
if img_path:
src = Path(img_path)
parent_name = src.parent.name or "default"
out_dir = images_root / parent_name
out_dir.mkdir(parents=True, exist_ok=True)
filename = src.name
else:
out_dir = images_root / "edit"
out_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
base = "sample"
filename = f"{base}_edit_{ts}_{uuid.uuid4().hex[:8]}.jpg"
out_path = out_dir / filename
elif mode in ["generate", "think_generate"]:
out_dir = images_root
out_dir.mkdir(parents=True, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
base = "sample"
filename = f"{base}_{ts}_{uuid.uuid4().hex[:8]}.jpg"
out_path = out_dir / filename
else:
raise ValueError(f"[OutputError] Unexpected mode for image saving: {mode}")
image.save(out_path)
return str(out_path)
def generate(self, question: str, images: list[str] | None = None, **kwargs):
mode = self.mode
images = images or []
# Auto-prepend <image> placeholders if the question doesn't contain them
existing_count = question.count("<image>")
if images and existing_count == 0:
question = "".join(["<image>\n" for _ in images]) + question
text_parts = question.split("<image>")
if len(text_parts) != len(images) + 1:
raise ValueError(f"Text iamge tokens and number of images not match! ")
input_lists = []
input_img_paths = []
for i, part in enumerate(text_parts):
text = part.strip()
if text:
input_lists.append(text)
if i < len(images):
img_path = images[i]
try:
image = Image.open(img_path)
input_lists.append(image)
input_img_paths.append(img_path)
except Exception as e:
raise RuntimeError(f"Can not load image {img_path}: {e}") from e
params = dict(BASE_PARAMS[mode])
understanding_output_flag = params.pop("understanding_output", False)
think_flag = params.pop("think", False)
res = self.inferencer.interleave_inference(
input_lists=input_lists,
think=think_flag,
understanding_output=understanding_output_flag,
**params,
)
ret = {"image": [], "text": []}
for i in res:
if isinstance(i, Image.Image):
ret["image"].append(i)
elif isinstance(i, str):
ret["text"].append(i)
img_cnt, txt_cnt = len(ret["image"]), len(ret["text"])
if img_cnt + txt_cnt != 1:
print(
f"[Warning] You are using {mode} mode, so the output has {img_cnt} images and {txt_cnt} texts"
)
if txt_cnt > 0:
print(f"[Warning] The text output is: {ret['text'][0]}")
ret["image"] = ret["image"][0] if img_cnt else None
ret["text"] = ret["text"][0] if txt_cnt else None
if mode in ["edit", "think_edit", "generate", "think_generate"]:
if ret["image"] is not None:
if len(input_img_paths) == 1:
ref_img_path = input_img_paths[0]
else:
ref_img_path = None
img_path_out = self._save_output_image(
image=ret["image"],
mode=mode,
img_path=ref_img_path,
)
ret["image"] = img_path_out
res = img_path_out
else:
res = None
else:
res = ret["text"]
return res
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import math
import random
import torch
from PIL import Image
# from torch.nn.attention.flex_attention import or_masks, and_masks
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def full_and_noise_mask(b, h, q_idx, kv_idx):
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (
full_and_noise_seq_id[q_idx] >= 0
)
def remove_noise_mask(b, h, q_idx, kv_idx):
return ~(
(noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])
)
def sample_mask(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
full_and_noise_tmp = []
noise_tmp = []
for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
value = i if model in ["full", "noise"] else -1
full_and_noise_tmp.extend([value] * length)
value_noise = i if model == "noise" else -1
noise_tmp.extend([value_noise] * length)
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
noise_seq_id = torch.Tensor(noise_tmp).to(device)
document_id = torch.cat(
[torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]
).to(device)
return and_masks(
or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask
)
def patchify(image, patch_size):
p = patch_size
c, h, w = image.shape
assert h % p == 0 and w % p == 0
image = image.reshape(c, h // p, p, w // p, p)
image = torch.einsum("chpwq->hwpqc", image)
image = image.reshape(-1, p**2 * c)
return image
def get_flattened_position_ids_extrapolate(
img_h, img_w, patch_size, max_num_patches_per_side
):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
coords_h = torch.arange(0, num_patches_h)
coords_w = torch.arange(0, num_patches_w)
pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
return pos_ids
def get_flattened_position_ids_interpolate(
img_h, img_w, patch_size, max_num_patches_per_side
):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
boundaries = torch.arange(
1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side
)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (
bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w
).flatten()
return pos_ids
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
"""
nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
a sample, where each sample contains multiple splits with different attn modes.
nested_attn_modes: whether to use full attn in each split.
"""
sample_len = sum(split_lens)
attention_mask = torch.zeros(
(sample_len, sample_len), dtype=torch.bool, device=device
)
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
assert attn_mode in ["causal", "full", "noise"]
if attn_mode == "causal":
attention_mask[csum : csum + s, csum : csum + s] = torch.ones(
(s, s), device=device
).tril()
attention_mask[csum : csum + s, :csum] = 1
else:
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
attention_mask[csum : csum + s, :csum] = 1
csum += s
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
if attn_mode == "noise":
attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
csum += s
attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
~attention_mask, float("-inf")
)
return attention_mask
def split_integer_exp_decay(S, ng_sample_decay=1.0):
if ng_sample_decay == 1.0:
N = random.randint(1, S)
else:
base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
N = random.choices(list(range(1, S + 1)), p, k=1)[0]
cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
result = [cumsum[i + 1] - cumsum[i] for i in range(len(cumsum) - 1)]
return result, cumsum
def pil_img2rgb(image):
if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
image = image.convert("RGBA")
white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
white.paste(image, mask=image.split()[3])
image = white
else:
image = image.convert("RGB")
return image
def add_special_tokens(tokenizer):
all_special_tokens = []
for k, v in tokenizer.special_tokens_map.items():
if isinstance(v, str):
all_special_tokens.append(v)
elif isinstance(v, list):
all_special_tokens += v
new_tokens = []
if "<|im_start|>" not in all_special_tokens:
new_tokens.append("<|im_start|>")
if "<|im_end|>" not in all_special_tokens:
new_tokens.append("<|im_end|>")
if "<|vision_start|>" not in all_special_tokens:
new_tokens.append("<|vision_start|>")
if "<|vision_end|>" not in all_special_tokens:
new_tokens.append("<|vision_end|>")
num_new_tokens = tokenizer.add_tokens(new_tokens)
bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
start_of_image = tokenizer.convert_tokens_to_ids("<|vision_start|>")
end_of_image = tokenizer.convert_tokens_to_ids("<|vision_end|>")
new_token_ids = dict(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
start_of_image=start_of_image,
end_of_image=end_of_image,
)
return tokenizer, new_token_ids, num_new_tokens
def len2weight(x, loss_reduction="square"):
if x == 0:
return x
if loss_reduction == "token":
return 1
if loss_reduction == "sample":
return 1 / x
if loss_reduction == "square":
return 1 / (x**0.5)
raise NotImplementedError(loss_reduction)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import random
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
"""Resize the input image so that its longest side and shortest side are within a specified range,
ensuring that both sides are divisible by a specified stride.
Args:
max_size (int): Maximum size for the longest edge of the image.
min_size (int): Minimum size for the shortest edge of the image.
stride (int): Value by which the height and width of the image must be divisible.
max_pixels (int): Maximum pixels for the full image.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
antialias (bool, optional): Whether to apply antialiasing (default is True).
"""
def __init__(
self,
max_size: int,
min_size: int,
stride: int,
max_pixels: int,
interpolation=InterpolationMode.BICUBIC,
antialias=True,
):
super().__init__()
self.max_size = max_size
self.min_size = min_size
self.stride = stride
self.max_pixels = max_pixels
self.interpolation = interpolation
self.antialias = antialias
def _make_divisible(self, value, stride):
"""Ensure the value is divisible by the stride."""
return max(stride, int(round(value / stride) * stride))
def _apply_scale(self, width, height, scale):
new_width = round(width * scale)
new_height = round(height * scale)
new_width = self._make_divisible(new_width, self.stride)
new_height = self._make_divisible(new_height, self.stride)
return new_width, new_height
def forward(self, img, img_num=1):
"""
Args:
img (PIL Image): Image to be resized.
img_num (int): Number of images, used to change max_tokens.
Returns:
PIL Image or Tensor: Rescaled image with divisible dimensions.
"""
if isinstance(img, torch.Tensor):
height, width = img.shape[-2:]
else:
width, height = img.size
scale = min(self.max_size / max(width, height), 1.0)
scale = max(scale, self.min_size / min(width, height))
new_width, new_height = self._apply_scale(width, height, scale)
# Ensure the number of pixels does not exceed max_pixels
if new_width * new_height > self.max_pixels / img_num:
scale = self.max_pixels / img_num / (new_width * new_height)
new_width, new_height = self._apply_scale(new_width, new_height, scale)
# Ensure longest edge does not exceed max_size
if max(new_width, new_height) > self.max_size:
scale = self.max_size / max(new_width, new_height)
new_width, new_height = self._apply_scale(new_width, new_height, scale)
return F.resize(
img, (new_height, new_width), self.interpolation, antialias=self.antialias
)
class ImageTransform:
def __init__(
self,
max_image_size,
min_image_size,
image_stride,
max_pixels=14 * 14 * 9 * 1024,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
self.stride = image_stride
self.resize_transform = MaxLongEdgeMinShortEdgeResize(
max_size=max_image_size,
min_size=min_image_size,
stride=image_stride,
max_pixels=max_pixels,
)
self.to_tensor_transform = transforms.ToTensor()
self.normalize_transform = transforms.Normalize(
mean=image_mean, std=image_std, inplace=True
)
def __call__(self, img, img_num=1):
img = self.resize_transform(img, img_num=img_num)
img = self.to_tensor_transform(img)
img = self.normalize_transform(img)
return img
def decolorization(image):
gray_image = image.convert("L")
return (
Image.merge(image.mode, [gray_image] * 3)
if image.mode in ("RGB", "L")
else gray_image
)
def downscale(image, scale_factor):
new_width = int(round(image.width * scale_factor))
new_height = int(round(image.height * scale_factor))
new_width = max(1, new_width)
new_height = max(1, new_height)
return image.resize((new_width, new_height), resample=Image.BICUBIC)
def crop(image, crop_factors):
target_h, target_w = crop_factors
img_w, img_h = image.size
if target_h > img_h or target_w > img_w:
raise ValueError("Crop size exceeds image dimensions")
x = random.randint(0, img_w - target_w)
y = random.randint(0, img_h - target_h)
return image.crop((x, y, x + target_w, y + target_h)), [
[x, y],
[x + target_w, y + target_h],
]
def motion_blur_opencv(image, kernel_size=15, angle=0):
# 线性核
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
# 旋转核
center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
M = cv2.getRotationMatrix2D(center, angle, 1)
rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
# 归一化核
rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
img = np.array(image)
if img.ndim == 2:
blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
else:
# 对于彩色图像,各通道独立卷积
blurred = np.zeros_like(img)
for c in range(img.shape[2]):
blurred[..., c] = cv2.filter2D(
img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT
)
return Image.fromarray(blurred.astype(np.uint8))
def shuffle_patch(image, num_splits, gap_size=2):
"""将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
h_splits, w_splits = num_splits
img_w, img_h = image.size
base_patch_h = img_h // h_splits
patch_heights = [base_patch_h] * (h_splits - 1)
patch_heights.append(img_h - sum(patch_heights))
base_patch_w = img_w // w_splits
patch_widths = [base_patch_w] * (w_splits - 1)
patch_widths.append(img_w - sum(patch_widths))
patches = []
current_y = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
patch_w = patch_widths[j]
patch = image.crop(
(current_x, current_y, current_x + patch_w, current_y + patch_h)
)
patches.append(patch)
current_x += patch_w
current_y += patch_h
random.shuffle(patches)
total_width = sum(patch_widths) + (w_splits - 1) * gap_size
total_height = sum(patch_heights) + (h_splits - 1) * gap_size
new_image = Image.new(
image.mode, (total_width, total_height), color=(255, 255, 255)
)
current_y = 0 # 当前行的起始 Y 坐标
patch_idx = 0 # 当前处理的块索引
for i in range(h_splits):
current_x = 0 # 当前列的起始 X 坐标
patch_h = patch_heights[i] # 当前行块的高度
for j in range(w_splits):
# 取出打乱后的块
patch = patches[patch_idx]
patch_w = patch_widths[j] # 当前列块的宽度
# 粘贴块(左上角坐标为 (current_x, current_y))
new_image.paste(patch, (current_x, current_y))
# 更新 X 坐标(下一个块的起始位置 = 当前块宽度 + 间隙)
current_x += patch_w + gap_size
patch_idx += 1
# 更新 Y 坐标(下一行的起始位置 = 当前行高度 + 间隙)
current_y += patch_h + gap_size
return new_image
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
"""
图像分割后随机空白部分patch,用于inpainting任务
参数:
image: PIL.Image 输入图像(RGB模式)
h_splits: int 行分割数(垂直方向分割块数)
w_splits: int 列分割数(水平方向分割块数)
blank_ratio: float 空白patch的比例(0~1)
blank_color: tuple 空白区域的颜色(RGB,如白色(255,255,255))
返回:
PIL.Image 处理后拼接的图像
"""
h_splits, w_splits = num_splits
img_w, img_h = image.size
base_patch_h = img_h // h_splits
patch_heights = [base_patch_h] * (h_splits - 1)
patch_heights.append(img_h - sum(patch_heights))
base_patch_w = img_w // w_splits
patch_widths = [base_patch_w] * (w_splits - 1)
patch_widths.append(img_w - sum(patch_widths))
patches = []
current_y = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
patch_w = patch_widths[j]
patch = image.crop(
(current_x, current_y, current_x + patch_w, current_y + patch_h)
)
patches.append(patch)
current_x += patch_w
current_y += patch_h
total_patches = h_splits * w_splits
num_blank = int(total_patches * blank_ratio)
num_blank = max(0, min(num_blank, total_patches))
blank_indices = random.sample(range(total_patches), num_blank)
processed_patches = []
for idx, patch in enumerate(patches):
if idx in blank_indices:
blank_patch = Image.new("RGB", patch.size, color=blank_color)
processed_patches.append(blank_patch)
else:
processed_patches.append(patch)
# 创建结果图像(尺寸与原图一致)
result_image = Image.new("RGB", (img_w, img_h))
current_y = 0
patch_idx = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
# 取出处理后的patch
patch = processed_patches[patch_idx]
patch_w = patch_widths[j]
# 粘贴到原位置
result_image.paste(patch, (current_x, current_y))
current_x += patch_w
patch_idx += 1
current_y += patch_h
return result_image
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image
from .data.data_utils import pil_img2rgb
from .modeling.bagel.qwen2_navit import NaiveCache
VLM_THINK_SYSTEM_PROMPT = """You should first think about the reasoning process in the mind and then provide the user with the answer.
The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here"""
GEN_THINK_SYSTEM_PROMPT = """You should first think about the planning process in the mind and then generate the image.
The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here"""
class InterleaveInferencer:
def __init__(
self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids
):
self.model = model
self.vae_model = vae_model
self.tokenizer = tokenizer
self.vae_transform = vae_transform
self.vit_transform = vit_transform
self.new_token_ids = new_token_ids
def init_gen_context(self):
gen_context = {
"kv_lens": [0],
"ropes": [0],
"past_key_values": NaiveCache(
self.model.config.llm_config.num_hidden_layers
),
}
return gen_context
@torch.no_grad()
def update_context_text(self, text, gen_context):
# used for interleave data, currently only support 1 data inference,
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input, kv_lens, ropes = self.model.prepare_prompts(
curr_kvlens=kv_lens,
curr_rope=ropes,
prompts=[text],
tokenizer=self.tokenizer,
new_token_ids=self.new_token_ids,
)
past_key_values = self.model.forward_cache_update_text(
past_key_values, **generation_input
)
gen_context["kv_lens"] = kv_lens
gen_context["ropes"] = ropes
gen_context["past_key_values"] = past_key_values
return gen_context
@torch.no_grad()
def update_context_image(self, image, gen_context, vae=True, vit=True):
# used for interleave data, currently only support 1 data inference,
assert vae or vit
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
if vae:
## update vae
generation_input, kv_lens, ropes = self.model.prepare_vae_images(
curr_kvlens=kv_lens,
curr_rope=ropes,
images=[image],
transforms=self.vae_transform,
new_token_ids=self.new_token_ids,
)
past_key_values = self.model.forward_cache_update_vae(
self.vae_model, past_key_values, **generation_input
)
if vit:
## update vit
generation_input, kv_lens, ropes = self.model.prepare_vit_images(
curr_kvlens=kv_lens,
curr_rope=ropes,
images=[image],
transforms=self.vit_transform,
new_token_ids=self.new_token_ids,
)
past_key_values = self.model.forward_cache_update_vit(
past_key_values, **generation_input
)
gen_context["kv_lens"] = kv_lens
gen_context["ropes"] = ropes
gen_context["past_key_values"] = past_key_values
return gen_context
@torch.no_grad()
def gen_image(
self,
image_shape,
gen_context,
cfg_text_scale=4.0,
cfg_img_scale=1.5,
cfg_text_precontext=None,
cfg_img_precontext=None,
cfg_interval=(0.4, 1.0),
cfg_renorm_min=0.0,
cfg_renorm_type="global",
num_timesteps=50,
timestep_shift=3.0,
enable_taylorseer=False,
):
# print(cfg_renorm_type)
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input = self.model.prepare_vae_latent(
curr_kvlens=kv_lens,
curr_rope=ropes,
image_sizes=[image_shape],
new_token_ids=self.new_token_ids,
)
# text cfg
cfg_text_past_key_values = cfg_text_precontext["past_key_values"]
kv_lens_cfg = cfg_text_precontext["kv_lens"]
ropes_cfg = cfg_text_precontext["ropes"]
generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
curr_kvlens=kv_lens_cfg,
curr_rope=ropes_cfg,
image_sizes=[image_shape],
)
# img cfg
cfg_img_past_key_values = cfg_img_precontext["past_key_values"]
kv_lens_cfg = cfg_img_precontext["kv_lens"]
ropes_cfg = cfg_img_precontext["ropes"]
generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
curr_kvlens=kv_lens_cfg,
curr_rope=ropes_cfg,
image_sizes=[image_shape],
)
unpacked_latent = self.model.generate_image(
past_key_values=past_key_values,
cfg_text_past_key_values=cfg_text_past_key_values,
cfg_img_past_key_values=cfg_img_past_key_values,
num_timesteps=num_timesteps,
cfg_text_scale=cfg_text_scale,
cfg_img_scale=cfg_img_scale,
cfg_interval=cfg_interval,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
timestep_shift=timestep_shift,
**generation_input,
cfg_text_packed_position_ids=generation_input_cfg_text[
"cfg_packed_position_ids"
],
cfg_text_packed_query_indexes=generation_input_cfg_text[
"cfg_packed_query_indexes"
],
cfg_text_key_values_lens=generation_input_cfg_text["cfg_key_values_lens"],
cfg_text_packed_key_value_indexes=generation_input_cfg_text[
"cfg_packed_key_value_indexes"
],
cfg_img_packed_position_ids=generation_input_cfg_img[
"cfg_packed_position_ids"
],
cfg_img_packed_query_indexes=generation_input_cfg_img[
"cfg_packed_query_indexes"
],
cfg_img_key_values_lens=generation_input_cfg_img["cfg_key_values_lens"],
cfg_img_packed_key_value_indexes=generation_input_cfg_img[
"cfg_packed_key_value_indexes"
],
enable_taylorseer=enable_taylorseer,
)
image = self.decode_image(unpacked_latent[0], image_shape)
return image
def decode_image(self, latent, image_shape):
H, W = image_shape
h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
latent = latent.reshape(
1,
h,
w,
self.model.latent_patch_size,
self.model.latent_patch_size,
self.model.latent_channel,
)
latent = torch.einsum("nhwpqc->nchpwq", latent)
latent = latent.reshape(
1,
self.model.latent_channel,
h * self.model.latent_patch_size,
w * self.model.latent_patch_size,
)
image = self.vae_model.decode(latent)
image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
return image
@torch.no_grad()
def gen_text(
self,
gen_context,
max_length: int = 500,
do_sample: bool = True,
temperature: float = 1.0,
):
gen_context = deepcopy(gen_context)
past_key_values = gen_context["past_key_values"]
kv_lens = gen_context["kv_lens"]
ropes = gen_context["ropes"]
generation_input = self.model.prepare_start_tokens(
kv_lens, ropes, self.new_token_ids
)
unpacked_latent = self.model.generate_text(
past_key_values=past_key_values,
max_length=max_length,
do_sample=do_sample,
temperature=temperature,
end_token_id=self.new_token_ids["eos_token_id"],
**generation_input,
)
output = self.tokenizer.decode(unpacked_latent[:, 0])
output = output.split("<|im_end|>")[0].split("<|im_start|>")[1]
return output
@torch.no_grad()
def interleave_inference(
self,
input_lists: List[Union[str, Image.Image]],
think=False,
understanding_output=False,
max_think_token_n=1000,
do_sample=False,
text_temperature=0.3,
cfg_text_scale=3.0,
cfg_img_scale=1.5,
cfg_interval=[0.4, 1.0],
timestep_shift=3.0,
num_timesteps=50,
cfg_renorm_min=0.0,
cfg_renorm_type="global",
image_shapes=(1024, 1024),
enable_taylorseer=False,
) -> List[Union[str, Image.Image]]:
output_list = []
gen_context = self.init_gen_context()
cfg_text_context = deepcopy(gen_context)
cfg_img_context = deepcopy(gen_context)
with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
if think:
if understanding_output:
system_prompt = VLM_THINK_SYSTEM_PROMPT
else:
system_prompt = GEN_THINK_SYSTEM_PROMPT
gen_context = self.update_context_text(system_prompt, gen_context)
cfg_img_context = self.update_context_text(
system_prompt, cfg_img_context
)
for input_term in input_lists:
if isinstance(input_term, str):
cfg_text_context = deepcopy(gen_context)
gen_context = self.update_context_text(input_term, gen_context)
cfg_img_context = self.update_context_text(
input_term, cfg_img_context
)
elif isinstance(input_term, Image.Image):
input_term = self.vae_transform.resize_transform(
pil_img2rgb(input_term)
)
gen_context = self.update_context_image(
input_term, gen_context, vae=not understanding_output
)
image_shapes = input_term.size[::-1]
cfg_text_context = deepcopy(gen_context)
else:
raise ValueError(f"Unsupported input type: {type(input_term)}")
if understanding_output:
gen_text = self.gen_text(
gen_context,
do_sample=do_sample,
temperature=text_temperature,
max_length=max_think_token_n,
)
output_list.append(gen_text)
else:
if think:
gen_text = self.gen_text(
gen_context,
do_sample=do_sample,
temperature=text_temperature,
max_length=max_think_token_n,
)
gen_context = self.update_context_text(gen_text, gen_context)
output_list.append(gen_text)
img = self.gen_image(
image_shapes,
gen_context,
cfg_text_precontext=cfg_text_context,
cfg_img_precontext=cfg_img_context,
cfg_text_scale=cfg_text_scale,
cfg_img_scale=cfg_img_scale,
cfg_interval=cfg_interval,
timestep_shift=timestep_shift,
num_timesteps=num_timesteps,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
enable_taylorseer=enable_taylorseer,
)
output_list.append(img)
return output_list
def __call__(
self, image: Optional[Image.Image] = None, text: Optional[str] = None, **kargs
) -> Dict[str, Any]:
output_dict = {"image": None, "text": None}
if image is None and text is None:
print("Please provide at least one input: either an image or text.")
return output_dict
input_list = []
if image is not None:
input_list.append(image)
if text is not None:
input_list.append(text)
output_list = self.interleave_inference(input_list, **kargs)
for i in output_list:
if isinstance(i, Image.Image):
output_dict["image"] = i
elif isinstance(i, str):
output_dict["text"] = i
return output_dict
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from . import autoencoder, bagel, qwen2, siglip
# Copyright (c) 2024 Black Forest Labs.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
#
# This modified file is released under the same license.
from dataclasses import dataclass
import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
downsample: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = nn.GroupNorm(
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = nn.Conv2d(
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_ae(local_path: str) -> AutoEncoder:
ae_params = AutoEncoderParams(
resolution=256,
in_channels=3,
downsample=8,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
# Loading the autoencoder
ae = AutoEncoder(ae_params)
if local_path is not None:
sd = load_sft(local_path)
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return ae, ae_params
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from .bagel import Bagel, BagelConfig
from .qwen2_navit import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
__all__ = [
"BagelConfig",
"Bagel",
"Qwen2Config",
"Qwen2Model",
"Qwen2ForCausalLM",
"SiglipVisionConfig",
"SiglipVisionModel",
]
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import copy
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
# from torch.nn.attention.flex_attention import create_block_mask
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from ...data.data_utils import (
create_sparse_mask,
get_flattened_position_ids_extrapolate,
get_flattened_position_ids_interpolate,
patchify,
)
from ..cache_utils.taylorseer import cache_init
from .modeling_utils import MLPconnector, PositionEmbedding, TimestepEmbedder
from .qwen2_navit import NaiveCache
class BagelConfig(PretrainedConfig):
def __init__(
self,
visual_gen=True,
visual_und=True,
llm_config=None,
vit_config=None,
vae_config=None,
latent_patch_size=2,
max_latent_size=32,
vit_max_num_patch_per_side=70,
connector_act="gelu_pytorch_tanh",
interpolate_pos=False,
timestep_shift=1.0,
**kwargs,
):
super().__init__(**kwargs)
self.visual_gen = visual_gen
self.visual_und = visual_und
self.llm_config = llm_config
self.vit_config = vit_config
self.vae_config = vae_config
self.latent_patch_size = latent_patch_size
self.max_latent_size = max_latent_size
self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
self.connector_act = connector_act
self.interpolate_pos = interpolate_pos
self.timestep_shift = timestep_shift
class Bagel(PreTrainedModel):
config_class = BagelConfig
base_model_prefix = "bagel"
def __init__(self, language_model, vit_model, config: BagelConfig):
super().__init__(config)
self.language_model = language_model
self.hidden_size = config.llm_config.hidden_size
self.use_moe = "Mo" in config.llm_config.layer_module
self.num_heads = config.llm_config.num_attention_heads
if config.visual_gen:
self.latent_patch_size = config.latent_patch_size
self.timestep_shift = config.timestep_shift
self.latent_downsample = (
config.vae_config.downsample * config.latent_patch_size
)
self.max_latent_size = config.max_latent_size
self.latent_channel = config.vae_config.z_channels
self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
self.time_embedder = TimestepEmbedder(self.hidden_size)
self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
self.latent_pos_embed = PositionEmbedding(
self.max_latent_size, self.hidden_size
)
if config.visual_und:
self.vit_model = vit_model
self.vit_patch_size = config.vit_config.patch_size
self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
self.vit_hidden_size = config.vit_config.hidden_size
self.connector = MLPconnector(
self.vit_hidden_size, self.hidden_size, config.connector_act
)
self.vit_pos_embed = PositionEmbedding(
self.vit_max_num_patch_per_side, self.hidden_size
)
if config.interpolate_pos:
self.get_flattened_position_ids = get_flattened_position_ids_interpolate
else:
self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
self.config = config
self._init_weights()
def _init_weights(self):
if self.config.visual_gen:
nn.init.constant_(self.llm2vae.weight, 0)
nn.init.constant_(self.llm2vae.bias, 0)
def forward(
self,
sequence_length: int,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
sample_lens: List[int],
packed_position_ids: torch.LongTensor,
nested_attention_masks: List[torch.Tensor] = None,
split_lens: List[int] = None,
attn_modes: List[str] = None,
# for visual understanding
ce_loss_indexes: Optional[torch.BoolTensor] = None,
packed_label_ids: Optional[torch.LongTensor] = None,
packed_vit_tokens: Optional[torch.Tensor] = None,
packed_vit_token_indexes: Optional[torch.LongTensor] = None,
packed_vit_position_ids: Optional[torch.LongTensor] = None,
vit_token_seqlens: Optional[torch.IntTensor] = None,
# for visual generation
padded_latent: Optional[torch.Tensor] = None,
patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
packed_latent_position_ids: Optional[torch.LongTensor] = None,
packed_vae_token_indexes: Optional[torch.LongTensor] = None,
packed_timesteps: Optional[torch.LongTensor] = None,
mse_loss_indexes: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
"""
Args:
sequence_length: length of sequence.
packed_text_ids: 1-D int tensor, packed text token ids.
packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
sample_lens: A list of N ints, length of each sample in packed_sequence.
nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and
-inf means ignore.
packed_position_ids: packed 1-D positions, an image has only one global position shared
by all latent tokens.
packed_vit_tokens: packed patchified image tokens for vit model.
packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
packed_label_ids: 1-D int tensor, packed label token ids.
ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
padded_latent: padded latent from VAE encoder.
patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image.
packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
"""
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
size=(sequence_length, self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
if nested_attention_masks is None:
sparse_mask = create_sparse_mask(
sample_lens, split_lens, attn_modes, packed_text_embedding.device
)
seqlen = sum(sample_lens)
block_mask = create_block_mask(
sparse_mask,
B=1,
H=self.num_heads,
Q_LEN=seqlen,
KV_LEN=seqlen,
device=packed_text_embedding.device,
BLOCK_SIZE=128,
_compile=True,
)
attention_mask = block_mask
else:
attention_mask = nested_attention_masks
if self.config.visual_und:
cu_seqlens = torch.nn.functional.pad(
torch.cumsum(vit_token_seqlens, dim=0), (1, 0)
)
cu_seqlens = cu_seqlens.to(torch.int32)
max_seqlen = torch.max(vit_token_seqlens).item()
packed_vit_token_embed = self.vit_model(
packed_pixel_values=packed_vit_tokens,
packed_flattened_position_ids=packed_vit_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
packed_vit_token_embed = self.connector(packed_vit_token_embed)
vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
if self.config.visual_gen:
p = self.latent_patch_size
packed_latent = []
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
latent = latent[:, : h * p, : w * p].reshape(
self.latent_channel, h, p, w, p
)
latent = torch.einsum("chpwq->hwpqc", latent).reshape(
-1, p * p * self.latent_channel
)
packed_latent.append(latent)
packed_latent_clean = torch.cat(packed_latent, dim=0)
noise = torch.randn_like(packed_latent_clean)
packed_timesteps = torch.sigmoid(packed_timesteps)
packed_timesteps = (
self.timestep_shift
* packed_timesteps
/ (1 + (self.timestep_shift - 1) * packed_timesteps)
)
packed_latent = (
1 - packed_timesteps[:, None]
) * packed_latent_clean + packed_timesteps[:, None] * noise
packed_timestep_embeds = self.time_embedder(packed_timesteps)
latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
packed_latent = (
self.vae2llm(packed_latent)
+ packed_timestep_embeds
+ latent_token_pos_emb
)
packed_sequence[packed_vae_token_indexes] = packed_latent
extra_inputs = {}
if self.use_moe:
packed_und_token_indexes = packed_text_indexes
if packed_vit_token_indexes is not None:
packed_und_token_indexes = torch.cat(
[packed_text_indexes, packed_vit_token_indexes], dim=0
)
extra_inputs.update(
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_vae_token_indexes,
)
last_hidden_state = self.language_model(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_ids=packed_position_ids,
**extra_inputs,
)
mse = None
if self.config.visual_gen:
packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
target = (
noise - packed_latent_clean
) # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
has_mse = packed_timesteps > 0
mse = (packed_mse_preds - target[has_mse]) ** 2
ce = None
if ce_loss_indexes is not None:
packed_ce_preds = self.language_model.lm_head(
last_hidden_state[ce_loss_indexes]
)
ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
return dict(mse=mse, ce=ce)
def prepare_prompts(
self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids
):
packed_text_ids = list()
packed_text_position_ids = list()
text_token_lens = list()
packed_text_indexes = list()
packed_key_value_indexes = list()
curr = 0
newlens, new_rope = list(), list()
for prompt, curr_kvlen, curr_position_id in zip(
prompts, curr_kvlens, curr_rope
):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
text_ids = tokenizer.encode(prompt)
text_ids = (
[new_token_ids["bos_token_id"]]
+ text_ids
+ [new_token_ids["eos_token_id"]]
)
text_token_lens.append(len(text_ids))
packed_text_ids.extend(text_ids)
packed_text_position_ids.extend(
range(curr_position_id, curr_position_id + len(text_ids))
)
packed_text_indexes.extend(range(curr, curr + len(text_ids)))
newlens.append(curr_kvlen + len(text_ids))
new_rope.append(curr_position_id + len(text_ids))
curr += len(text_ids)
generation_input = {
"text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_position_ids": torch.tensor(
packed_text_position_ids, dtype=torch.long
),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}
return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_text(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.IntTensor,
packed_text_position_ids: torch.LongTensor,
text_token_lens: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_text_embedding,
query_lens=text_token_lens,
packed_query_position_ids=packed_text_position_ids,
packed_query_indexes=packed_text_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=True,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vit_images(
self, curr_kvlens, curr_rope, images, transforms, new_token_ids
):
packed_vit_token_indexes = list()
vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = (
list(),
list(),
list(),
)
packed_text_ids, packed_text_indexes = list(), list()
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()
_curr = curr = 0
newlens, new_rope = list(), list()
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
image_tensor = transforms(image)
vit_position_ids = self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.vit_patch_size,
max_num_patches_per_side=self.vit_max_num_patch_per_side,
)
vit_tokens = patchify(image_tensor, self.vit_patch_size)
packed_vit_tokens.append(vit_tokens)
num_img_tokens = vit_tokens.shape[0]
packed_vit_position_ids.append(vit_position_ids)
vit_token_seqlens.append(num_img_tokens)
packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
packed_indexes.extend(range(curr, curr + num_img_tokens))
curr += num_img_tokens
_curr += num_img_tokens
packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
packed_seqlens.append(num_img_tokens + 2)
newlens.append(curr_kvlen + num_img_tokens + 2)
new_rope.append(curr_position_id + 1)
generation_input = {
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
"packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
"packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
"packed_vit_token_indexes": torch.tensor(
packed_vit_token_indexes, dtype=torch.long
),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}
return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vit(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_vit_tokens: torch.Tensor,
packed_vit_token_indexes: torch.LongTensor,
packed_vit_position_ids: torch.LongTensor,
vit_token_seqlens: torch.IntTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
(sum(packed_seqlens), self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
cu_seqlens = torch.nn.functional.pad(
torch.cumsum(vit_token_seqlens, dim=0), (1, 0)
)
cu_seqlens = cu_seqlens.to(torch.int32)
max_seqlen = torch.max(vit_token_seqlens).item()
packed_vit_token_embed = self.vit_model(
packed_pixel_values=packed_vit_tokens,
packed_flattened_position_ids=packed_vit_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
packed_vit_token_embed = self.connector(packed_vit_token_embed)
pos_emb = self.vit_pos_embed(packed_vit_position_ids)
packed_vit_token_embed = packed_vit_token_embed + pos_emb
if packed_vit_token_embed.dtype != packed_sequence.dtype:
packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vae_images(
self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0
):
patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
packed_vae_token_indexes = list()
packed_text_ids, packed_text_indexes = list(), list()
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()
_curr = curr = 0
vae_image_tensors = list()
newlens, new_rope = list(), list()
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
image_tensor = transforms(image)
vae_image_tensors.append(image_tensor)
vae_posiiton_ids = self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.latent_downsample,
max_num_patches_per_side=self.max_latent_size,
)
packed_vae_position_ids.append(vae_posiiton_ids)
H, W = image_tensor.shape[1:]
h = H // self.latent_downsample
w = W // self.latent_downsample
patchified_vae_latent_shapes.append((h, w))
num_img_tokens = w * h
packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
packed_indexes.extend(range(curr, curr + num_img_tokens))
curr += num_img_tokens
_curr += num_img_tokens
packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
packed_seqlens.append(num_img_tokens + 2)
newlens.append(curr_kvlen + num_img_tokens + 2)
new_rope.append(curr_position_id + 1)
image_sizes = [item.shape for item in vae_image_tensors]
max_image_size = [max(item) for item in list(zip(*image_sizes))]
padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
for i, image_tensor in enumerate(vae_image_tensors):
padded_images[i, :, : image_tensor.shape[1], : image_tensor.shape[2]] = (
image_tensor
)
generation_input = {
"padded_images": padded_images,
"patchified_vae_latent_shapes": patchified_vae_latent_shapes,
"packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
"packed_timesteps": torch.tensor([timestep]),
"packed_vae_token_indexes": torch.tensor(
packed_vae_token_indexes, dtype=torch.long
),
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}
return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vae(
self,
vae_model,
past_key_values: NaiveCache,
padded_images: torch.Tensor,
patchified_vae_latent_shapes: List,
packed_vae_position_ids: torch.LongTensor,
packed_timesteps: torch.Tensor,
packed_vae_token_indexes: torch.LongTensor,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
packed_key_value_indexes: torch.Tensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
(sum(packed_seqlens), self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
padded_latent = vae_model.encode(padded_images)
p = self.latent_patch_size
packed_latent = list()
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
latent = latent[:, : h * p, : w * p].reshape(
self.latent_channel, h, p, w, p
)
latent = torch.einsum("chpwq->hwpqc", latent).reshape(
-1, p * p * self.latent_channel
)
packed_latent.append(latent)
packed_latent = torch.cat(packed_latent, dim=0)
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
packed_timestep_embeds = self.time_embedder(packed_timesteps)
packed_latent = (
self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
)
if packed_latent.dtype != packed_sequence.dtype:
packed_latent = packed_latent.to(packed_sequence.dtype)
packed_sequence[packed_vae_token_indexes] = packed_latent
extra_inputs = {}
if self.use_moe:
extra_inputs = {
"mode": "gen",
"packed_vae_token_indexes": packed_vae_token_indexes,
"packed_text_indexes": packed_text_indexes,
}
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
packed_text_ids, packed_text_indexes = list(), list()
packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = (
list(),
list(),
list(),
)
packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()
query_curr = curr = 0
for (H, W), curr_kvlen, curr_position_id in zip(
image_sizes, curr_kvlens, curr_rope
):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(query_curr)
packed_indexes.append(curr)
curr += 1
query_curr += 1
vae_posiiton_ids = self.get_flattened_position_ids(
H,
W,
self.latent_downsample,
max_num_patches_per_side=self.max_latent_size,
)
packed_vae_position_ids.append(vae_posiiton_ids)
h, w = H // self.latent_downsample, W // self.latent_downsample
num_image_tokens = h * w
packed_init_noises.append(
torch.randn(
num_image_tokens, self.latent_channel * self.latent_patch_size**2
)
)
packed_vae_token_indexes.extend(
range(query_curr, query_curr + num_image_tokens)
)
packed_indexes.extend(range(curr, curr + num_image_tokens))
curr += num_image_tokens
query_curr += num_image_tokens
packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(query_curr)
packed_indexes.append(curr)
curr += 1
query_curr += 1
packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
packed_seqlens.append(num_image_tokens + 2)
generation_input = {
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_init_noises": torch.cat(packed_init_noises, dim=0),
"packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
"packed_vae_token_indexes": torch.tensor(
packed_vae_token_indexes, dtype=torch.long
),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
}
return generation_input
def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
packed_position_ids, packed_indexes, packed_key_value_indexes = (
list(),
list(),
list(),
)
query_curr = curr = 0
for (H, W), curr_kvlen, curr_position_id in zip(
image_sizes, curr_kvlens, curr_rope
):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_indexes.append(curr)
curr += 1
query_curr += 1
h, w = H // self.latent_downsample, W // self.latent_downsample
num_image_tokens = h * w
packed_indexes.extend(range(curr, curr + num_image_tokens))
curr += num_image_tokens
query_curr += num_image_tokens
packed_indexes.append(curr)
curr += 1
query_curr += 1
packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
generation_input = {
"cfg_packed_position_ids": torch.tensor(
packed_position_ids, dtype=torch.long
),
"cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
"cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"cfg_packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
}
return generation_input
@torch.no_grad
def generate_image(
self,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_init_noises: torch.Tensor,
packed_vae_position_ids: torch.LongTensor,
packed_vae_token_indexes: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_position_ids: torch.LongTensor,
packed_indexes: torch.LongTensor,
past_key_values: NaiveCache,
key_values_lens: torch.IntTensor,
packed_key_value_indexes: torch.LongTensor,
num_timesteps: int = 24,
timestep_shift: float = 1.0,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
cfg_interval: Optional[Tuple[float, float]] = [0, 1],
# cfg_text
cfg_text_scale: float = 1.0,
cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_text_past_key_values: Optional[NaiveCache] = None,
cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
# cfg_img
cfg_img_scale: float = 1.0,
cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_img_past_key_values: Optional[NaiveCache] = None,
cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
cfg_type: str = "parallel",
# cache_args
enable_taylorseer=False,
):
if enable_taylorseer:
self.language_model.model.enable_taylorseer = True
model_pred_cache_dic, model_pred_current = cache_init(self, num_timesteps)
model_pred_text_cache_dic, model_pred_text_current = cache_init(
self, num_timesteps
)
model_pred_img_cache_dic, model_pred_img_current = cache_init(
self, num_timesteps
)
else:
self.language_model.model.enable_taylorseer = False
model_pred_cache_dic, model_pred_current = None, None
model_pred_text_cache_dic, model_pred_text_current = None, None
model_pred_img_cache_dic, model_pred_img_current = None, None
x_t = packed_init_noises
timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
dts = timesteps[:-1] - timesteps[1:]
timesteps = timesteps[:-1]
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
if t > cfg_interval[0] and t <= cfg_interval[1]:
cfg_text_scale_ = cfg_text_scale
cfg_img_scale_ = cfg_img_scale
else:
cfg_text_scale_ = 1.0
cfg_img_scale_ = 1.0
v_t = self._forward_flow(
x_t=x_t,
timestep=timestep,
packed_vae_token_indexes=packed_vae_token_indexes,
packed_vae_position_ids=packed_vae_position_ids,
packed_text_ids=packed_text_ids,
packed_text_indexes=packed_text_indexes,
packed_position_ids=packed_position_ids,
packed_indexes=packed_indexes,
packed_seqlens=packed_seqlens,
key_values_lens=key_values_lens,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
# cfg_text
cfg_text_scale=cfg_text_scale_,
cfg_text_packed_position_ids=cfg_text_packed_position_ids,
cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
cfg_text_key_values_lens=cfg_text_key_values_lens,
cfg_text_past_key_values=cfg_text_past_key_values,
cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
# cfg_img
cfg_img_scale=cfg_img_scale_,
cfg_img_packed_position_ids=cfg_img_packed_position_ids,
cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
cfg_img_key_values_lens=cfg_img_key_values_lens,
cfg_img_past_key_values=cfg_img_past_key_values,
cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
cfg_type=cfg_type,
# cache
model_pred_cache_dic=model_pred_cache_dic,
model_pred_current=model_pred_current,
model_pred_text_cache_dic=model_pred_text_cache_dic,
model_pred_text_current=model_pred_text_current,
model_pred_img_cache_dic=model_pred_img_cache_dic,
model_pred_img_current=model_pred_img_current,
)
x_t = (
x_t - v_t.to(x_t.device) * dts[i]
) # velocity pointing from data to noise
if enable_taylorseer:
del model_pred_cache_dic, model_pred_current
del model_pred_text_cache_dic, model_pred_text_current
del model_pred_img_cache_dic, model_pred_img_current
unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
return unpacked_latent
@torch.no_grad
def _forward_flow(
self,
x_t: torch.Tensor,
timestep: torch.LongTensor,
packed_vae_token_indexes: torch.LongTensor,
packed_vae_position_ids: torch.LongTensor,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_indexes: torch.LongTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
key_values_lens: torch.IntTensor,
past_key_values: NaiveCache,
packed_key_value_indexes: torch.LongTensor,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
# cfg_text
cfg_text_scale: float = 1.0,
cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_text_key_values_lens: Optional[torch.Tensor] = None,
cfg_text_past_key_values: Optional[NaiveCache] = None,
cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
# cfg_img
cfg_img_scale: float = 1.0,
cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_img_key_values_lens: Optional[torch.Tensor] = None,
cfg_img_past_key_values: Optional[NaiveCache] = None,
cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
cfg_type: str = "parallel",
# cache
model_pred_cache_dic: Optional[Dict[str, Any]] = None,
model_pred_current: Optional[int] = None,
model_pred_text_cache_dic: Optional[Dict[str, Any]] = None,
model_pred_text_current: Optional[int] = None,
model_pred_img_cache_dic: Optional[Dict[str, Any]] = None,
model_pred_img_current: Optional[int] = None,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
(sum(packed_seqlens), self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
assert timestep.unique().shape[0] == 1
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
packed_timestep_embeds = self.time_embedder(timestep)
x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
if x_t.dtype != packed_sequence.dtype:
x_t = x_t.to(packed_sequence.dtype)
packed_sequence[packed_vae_token_indexes] = x_t
extra_inputs = {}
if self.use_moe:
extra_inputs = {
"mode": "gen",
"packed_vae_token_indexes": packed_vae_token_indexes,
"packed_text_indexes": packed_text_indexes,
}
if self.language_model.model.enable_taylorseer:
self.language_model.model.cache_dic = model_pred_cache_dic
self.language_model.model.current = model_pred_current
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
v_t = self.llm2vae(output.packed_query_sequence)
v_t = v_t[packed_vae_token_indexes]
if cfg_text_scale > 1.0:
if self.language_model.model.enable_taylorseer:
self.language_model.model.cache_dic = model_pred_text_cache_dic
self.language_model.model.current = model_pred_text_current
cfg_text_output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=cfg_text_packed_position_ids,
packed_query_indexes=cfg_text_packed_query_indexes,
past_key_values=cfg_text_past_key_values,
key_values_lens=cfg_text_key_values_lens,
packed_key_value_indexes=cfg_text_packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
if cfg_img_scale > 1.0:
if self.language_model.model.enable_taylorseer:
self.language_model.model.cache_dic = model_pred_img_cache_dic
self.language_model.model.current = model_pred_img_current
cfg_img_output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=cfg_img_packed_position_ids,
packed_query_indexes=cfg_img_packed_query_indexes,
past_key_values=cfg_img_past_key_values,
key_values_lens=cfg_img_key_values_lens,
packed_key_value_indexes=cfg_img_packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
if cfg_text_scale > 1.0:
if cfg_renorm_type == "text_channel":
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(
min=cfg_renorm_min, max=1.0
)
v_t_text = v_t_text_ * scale
if cfg_img_scale > 1.0:
v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
else:
v_t = v_t_text
else:
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
if cfg_img_scale > 1.0:
v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
else:
v_t_ = v_t_text_
# NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
if cfg_renorm_type == "global":
norm_v_t = torch.norm(v_t)
norm_v_t_ = torch.norm(v_t_)
elif cfg_renorm_type == "channel":
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
else:
raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted")
scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(
min=cfg_renorm_min, max=1.0
)
v_t = v_t_ * scale
else:
# No CFG
pass
return v_t
def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
packed_start_tokens, packed_key_value_indexes = list(), list()
packed_query_position_ids = list()
curr = 0
for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
packed_start_tokens.append(new_token_ids["bos_token_id"])
packed_query_position_ids.append(curr_position_id)
curr += curr_kvlen
generation_input = {
"packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
"packed_query_position_ids": torch.tensor(
packed_query_position_ids, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
}
return generation_input
@torch.no_grad
def generate_text(
self,
past_key_values: NaiveCache,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
packed_start_tokens: torch.LongTensor,
packed_query_position_ids: torch.LongTensor,
max_length: int,
do_sample: bool = False,
temperature: float = 1.0,
end_token_id: int = None,
):
step = 0
generated_sequence = []
curr_tokens = packed_start_tokens
while step < max_length:
generated_sequence.append(curr_tokens)
packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
query_lens = torch.ones_like(curr_tokens)
packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
0,
len(key_values_lens),
device=key_values_lens.device,
dtype=key_values_lens.dtype,
)
uppacked = list(
packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)
)
for i in range(len(uppacked)):
uppacked[i] += i
packed_key_value_indexes = torch.cat(uppacked, dim=0)
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_text_embedding,
query_lens=query_lens,
packed_query_position_ids=packed_query_position_ids,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=True,
is_causal=True,
**extra_inputs,
)
past_key_values = output.past_key_values
packed_query_sequence = output.packed_query_sequence
pred_logits = self.language_model.lm_head(packed_query_sequence)
if do_sample:
probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
curr_tokens = torch.argmax(pred_logits, dim=-1)
uppacked = list(
packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)
)
for i in range(len(uppacked)):
uppacked[i] = torch.cat(
[
uppacked[i],
torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device),
],
dim=0,
)
packed_key_value_indexes = torch.cat(uppacked, dim=0)
key_values_lens = key_values_lens + 1
packed_query_position_ids = packed_query_position_ids + 1
step += 1
if (
end_token_id is not None and curr_tokens[0] == end_token_id
): # only support batch=1
break
output_device = generated_sequence[0].device
return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
# for evaluation
@torch.no_grad()
def chat(
self,
tokenizer,
new_token_ids,
image_transform,
images,
prompt,
max_length: int,
do_sample: bool = False,
temperature: float = 1.0,
):
device = next(self.parameters()).device
if isinstance(new_token_ids, dict):
for k, v in new_token_ids.items():
if torch.is_tensor(v):
new_token_ids[k] = v.to(device)
elif torch.is_tensor(new_token_ids):
new_token_ids = new_token_ids.to(device)
# prefill
past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
newlens = [0]
new_rope = [0]
# add images
for image in images:
generation_input, newlens, new_rope = self.prepare_vit_images(
curr_kvlens=newlens,
curr_rope=new_rope,
images=[image],
transforms=image_transform,
new_token_ids=new_token_ids,
)
for k, v in generation_input.items():
if torch.is_tensor(v):
generation_input[k] = v.to(device)
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
past_key_values = self.forward_cache_update_vit(
past_key_values, **generation_input
)
# add text
generation_input, newlens, new_rope = self.prepare_prompts(
curr_kvlens=newlens,
curr_rope=new_rope,
prompts=[prompt],
tokenizer=tokenizer,
new_token_ids=new_token_ids,
)
for k, v in generation_input.items():
if torch.is_tensor(v):
generation_input[k] = v.to(device)
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
past_key_values = self.forward_cache_update_text(
past_key_values, **generation_input
)
# decode
generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
for k, v in generation_input.items():
if torch.is_tensor(v):
generation_input[k] = v.to(device)
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
unpacked_latent = self.generate_text(
past_key_values=past_key_values,
max_length=max_length,
do_sample=do_sample,
temperature=temperature,
end_token_id=new_token_ids["eos_token_id"],
**generation_input,
)
output = tokenizer.decode(unpacked_latent[:, 0])
output = output.split("<|im_end|>")[0].split("<|im_start|>")[1]
return output
# Copyright (c) 2022 Facebook, Inc. and its affiliates.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: CC BY-NC 4.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under CC BY-NC 4.0, with the full license text
# available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
#
# This modified file is released under the same license.
import math
import numpy as np
import torch
from torch import nn
from transformers.activations import ACT2FN
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# --------------------------------------------------------
# TimestepEmbedder
# Reference:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class MLPconnector(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
super().__init__()
self.activation_fn = ACT2FN[hidden_act]
self.fc1 = nn.Linear(in_dim, out_dim)
self.fc2 = nn.Linear(out_dim, out_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class PositionEmbedding(nn.Module):
def __init__(self, max_num_patch_per_side, hidden_size):
super().__init__()
self.max_num_patch_per_side = max_num_patch_per_side
self.hidden_size = hidden_size
self.pos_embed = nn.Parameter(
torch.zeros(max_num_patch_per_side**2, hidden_size), requires_grad=False
)
self._init_weights()
def _init_weights(self):
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size, self.max_num_patch_per_side
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
def forward(self, position_ids):
return self.pos_embed[position_ids]
# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import torch
from flash_attn import flash_attn_varlen_func
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
# from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention
from transformers.utils import ModelOutput
from ..cache_utils.taylorseer import (
cal_type,
derivative_approximation,
taylor_cache_init,
taylor_formula,
)
from ..qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
from ..qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2MLP,
Qwen2PreTrainedModel,
Qwen2RMSNorm,
Qwen2RotaryEmbedding,
apply_rotary_pos_emb,
)
torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096
# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
# flex_attention = torch.compile(flex_attention)
class Qwen2Config(_Qwen2Config):
r"""
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen2Model, Qwen2Config
>>> # Initializing a Qwen2 style configuration
>>> configuration = Qwen2Config()
>>> # Initializing a model from the Qwen2-7B style configuration
>>> model = Qwen2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
is_causal=True,
_attn_implementation="flash_attention_2",
qk_norm=True,
layer_module="Qwen2DecoderLayer",
freeze_und=False,
**kwargs,
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
hidden_act=hidden_act,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
rms_norm_eps=rms_norm_eps,
use_cache=use_cache,
tie_word_embeddings=tie_word_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
use_sliding_window=use_sliding_window,
sliding_window=sliding_window,
max_window_layers=max_window_layers,
attention_dropout=attention_dropout,
is_causal=is_causal,
_attn_implementation=_attn_implementation,
**kwargs,
)
self.qk_norm = qk_norm
self.layer_module = layer_module
self.freeze_und = freeze_und
class NaiveCache:
def __init__(self, num_layers):
self.key_cache = {k: None for k in range(num_layers)}
self.value_cache = {k: None for k in range(num_layers)}
@property
def num_layers(self):
return len(self.key_cache)
@property
def seq_lens(self):
if self.key_cache[0] is not None:
return self.key_cache[0].shape[0]
else:
return 0
@dataclass
class BaseNavitOutputWithPast(ModelOutput):
packed_query_sequence: torch.FloatTensor = None
past_key_values: Optional[NaiveCache] = None
def pad_sequence(tensor, pad_size):
H, L, D = tensor.shape
pad_tensor = tensor.new_zeros((H, pad_size, D))
return torch.cat([tensor, pad_tensor], dim=1)
class PackedAttention(Qwen2Attention):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
if self.config.qk_norm:
self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask: List[torch.Tensor],
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
):
packed_query_states = self.q_proj(packed_sequence).view(
-1, self.num_heads, self.head_dim
)
packed_key_states = self.k_proj(packed_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = self.v_proj(packed_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = self.q_norm(packed_query_states)
packed_key_states = self.k_norm(packed_key_states)
packed_cos, packed_sin = packed_position_embeddings
packed_query_states, packed_key_states = apply_rotary_pos_emb(
packed_query_states,
packed_key_states,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
if isinstance(attention_mask, List):
packed_key_states = packed_key_states[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_key_states = packed_key_states.reshape(
-1, self.num_heads, self.head_dim
)
packed_value_states = packed_value_states[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_value_states = packed_value_states.reshape(
-1, self.num_heads, self.head_dim
)
unpacked_query_states = packed_query_states.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_key_states = packed_key_states.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_value_states = packed_value_states.transpose(0, 1).split(
sample_lens, dim=1
)
upacked_attn_output = []
for (
query_states,
key_states,
value_states,
attention_mask_per_sample,
) in zip(
unpacked_query_states,
unpacked_key_states,
unpacked_value_states,
attention_mask,
):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
attn_output = scaled_dot_product_attention(
query_states.to(torch.bfloat16).unsqueeze(0),
key_states.to(torch.bfloat16).unsqueeze(0),
value_states.to(torch.bfloat16).unsqueeze(0),
attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
)
upacked_attn_output.append(attn_output.squeeze(0))
packed_attn_output = torch.cat(upacked_attn_output, dim=1)
else:
pad_size = sum(sample_lens) - packed_query_states.shape[0]
packed_query_states = pad_sequence(
packed_query_states.permute(1, 0, 2), pad_size
)
packed_key_states = pad_sequence(
packed_key_states.permute(1, 0, 2), pad_size
)
packed_value_states = pad_sequence(
packed_value_states.permute(1, 0, 2), pad_size
)
packed_attn_output = flex_attention(
packed_query_states.unsqueeze(0),
packed_key_states.unsqueeze(0),
packed_value_states.unsqueeze(0),
enable_gqa=True,
block_mask=attention_mask,
)
end_index = packed_attn_output.shape[2] - pad_size
packed_attn_output = packed_attn_output[0, :, :end_index, :]
packed_attn_output = packed_attn_output.transpose(0, 1).reshape(
-1, self.hidden_size
)
packed_attn_output = self.o_proj(packed_attn_output)
return packed_attn_output
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
):
packed_query_states = self.q_proj(packed_query_sequence).view(
-1, self.num_heads, self.head_dim
)
packed_key_states = self.k_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = self.v_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = self.q_norm(packed_query_states)
packed_key_states = self.k_norm(packed_key_states)
packed_cos, packed_sin = packed_query_position_embeddings
packed_query_states, packed_key_states = apply_rotary_pos_emb(
packed_query_states,
packed_key_states,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
packed_query_states = packed_query_states.to(torch.bfloat16)
packed_key_states = packed_key_states.to(torch.bfloat16)
packed_value_states = packed_value_states.to(torch.bfloat16)
if (
past_key_values is not None
and past_key_values.key_cache[self.layer_idx] is not None
):
past_key_states = past_key_values.key_cache[self.layer_idx]
past_value_states = past_key_values.value_cache[self.layer_idx]
seqlens = sum(query_lens) + sum(key_values_lens)
merged_key_states = past_key_states.new_zeros(
(seqlens, self.num_key_value_heads, self.head_dim)
)
merged_value_states = past_key_states.new_zeros(
(seqlens, self.num_key_value_heads, self.head_dim)
)
merged_key_states[packed_query_indexes] = packed_key_states
merged_key_states[packed_key_value_indexes] = past_key_states
merged_value_states[packed_query_indexes] = packed_value_states
merged_value_states[packed_key_value_indexes] = past_value_states
key_values_lens = key_values_lens + query_lens
else:
merged_key_states = packed_key_states
merged_value_states = packed_value_states
key_values_lens = query_lens
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(key_values_lens, dim=0), (1, 0)
)
packed_attn_output = flash_attn_varlen_func(
q=packed_query_states,
k=merged_key_states,
v=merged_value_states,
cu_seqlens_q=cu_seqlens_q.to(torch.int32),
cu_seqlens_k=cu_seqlens_k.to(torch.int32),
max_seqlen_q=max(query_lens).item(),
max_seqlen_k=max(key_values_lens).item(),
causal=is_causal,
)
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
packed_attn_output = self.o_proj(packed_attn_output)
if update_past_key_values:
past_key_values.key_cache[self.layer_idx] = merged_key_states
past_key_values.value_cache[self.layer_idx] = merged_value_states
return packed_attn_output, past_key_values
class PackedAttentionMoT(Qwen2Attention):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
if self.config.qk_norm:
self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.q_norm_moe_gen = nn.Identity()
self.k_norm_moe_gen = nn.Identity()
self.q_proj_moe_gen = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=True
)
self.k_proj_moe_gen = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj_moe_gen = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.o_proj_moe_gen = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
):
packed_query_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_heads * self.head_dim)
)
packed_key_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
)
packed_value_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
)
packed_sequence_und = packed_sequence[packed_und_token_indexes]
packed_sequence_gen = packed_sequence[packed_gen_token_indexes]
packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(
packed_sequence_gen
)
packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(
packed_sequence_gen
)
packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(
packed_sequence_gen
)
packed_query_states = packed_query_states.view(
-1, self.num_heads, self.head_dim
)
packed_key_states = packed_key_states.view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = packed_value_states.view(
-1, self.num_key_value_heads, self.head_dim
)
if self.config.freeze_und:
packed_value_states[packed_und_token_indexes] = packed_value_states[
packed_und_token_indexes
].detach()
packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)
packed_query_states_[packed_und_token_indexes] = self.q_norm(
packed_query_states[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_query_states_[packed_und_token_indexes] = packed_query_states_[
packed_und_token_indexes
].detach()
packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(
packed_query_states[packed_gen_token_indexes]
)
packed_key_states_[packed_und_token_indexes] = self.k_norm(
packed_key_states[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_key_states_[packed_und_token_indexes] = packed_key_states_[
packed_und_token_indexes
].detach()
packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(
packed_key_states[packed_gen_token_indexes]
)
packed_cos, packed_sin = packed_position_embeddings
packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
packed_query_states_,
packed_key_states_,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
if isinstance(attention_mask, List):
packed_key_states_ = packed_key_states_[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_key_states_ = packed_key_states_.reshape(
-1, self.num_heads, self.head_dim
)
packed_value_states = packed_value_states[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_value_states = packed_value_states.reshape(
-1, self.num_heads, self.head_dim
)
unpacked_query_states = packed_query_states_.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_key_states = packed_key_states_.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_value_states = packed_value_states.transpose(0, 1).split(
sample_lens, dim=1
)
upacked_attn_output = []
for (
query_states,
key_states,
value_states,
attention_mask_per_sample,
) in zip(
unpacked_query_states,
unpacked_key_states,
unpacked_value_states,
attention_mask,
):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
attn_output = scaled_dot_product_attention(
query_states.to(torch.bfloat16).unsqueeze(0),
key_states.to(torch.bfloat16).unsqueeze(0),
value_states.to(torch.bfloat16).unsqueeze(0),
attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
)
upacked_attn_output.append(attn_output.squeeze(0))
packed_attn_output = torch.cat(upacked_attn_output, dim=1)
else:
pad_size = sum(sample_lens) - packed_query_states.shape[0]
packed_query_states_ = pad_sequence(
packed_query_states_.permute(1, 0, 2), pad_size
)
packed_key_states_ = pad_sequence(
packed_key_states_.permute(1, 0, 2), pad_size
)
packed_value_states = pad_sequence(
packed_value_states.permute(1, 0, 2), pad_size
)
packed_attn_output = flex_attention(
packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim
packed_key_states_.unsqueeze(0),
packed_value_states.unsqueeze(0),
enable_gqa=True,
block_mask=attention_mask,
)
end_index = packed_attn_output.shape[2] - pad_size
packed_attn_output = packed_attn_output[0, :, :end_index, :]
packed_attn_output = packed_attn_output.transpose(0, 1).reshape(
-1, self.num_heads * self.head_dim
)
packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
packed_attn_output_[packed_und_token_indexes] = self.o_proj(
packed_attn_output[packed_und_token_indexes]
)
packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(
packed_attn_output[packed_gen_token_indexes]
)
return packed_attn_output_
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
):
if mode == "und":
packed_query_states = self.q_proj(packed_query_sequence).view(
-1, self.num_heads, self.head_dim
)
packed_key_states = self.k_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = self.v_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = self.q_norm(packed_query_states)
packed_key_states = self.k_norm(packed_key_states)
elif mode == "gen":
packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
packed_query_states = packed_query_sequence.new_zeros(
(packed_query_sequence.shape[0], self.num_heads * self.head_dim)
)
packed_key_states = packed_query_sequence.new_zeros(
(
packed_query_sequence.shape[0],
self.num_key_value_heads * self.head_dim,
)
)
packed_value_states = packed_query_sequence.new_zeros(
(
packed_query_sequence.shape[0],
self.num_key_value_heads * self.head_dim,
)
)
packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
packed_query_states[packed_text_indexes] = self.q_proj(
packed_text_query_sequence
)
packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(
packed_vae_query_sequence
)
packed_key_states[packed_text_indexes] = self.k_proj(
packed_text_query_sequence
)
packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(
packed_vae_query_sequence
)
packed_value_states[packed_text_indexes] = self.v_proj(
packed_text_query_sequence
)
packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(
packed_vae_query_sequence
)
packed_query_states = packed_query_states.view(
-1, self.num_heads, self.head_dim
)
packed_key_states = packed_key_states.view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = packed_value_states.view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = packed_query_states.to(torch.float32)
packed_query_states[packed_text_indexes] = self.q_norm(
packed_query_states[packed_text_indexes]
)
packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(
packed_query_states[packed_vae_token_indexes]
)
packed_key_states = packed_key_states.to(torch.float32)
packed_key_states[packed_text_indexes] = self.k_norm(
packed_key_states[packed_text_indexes]
)
packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(
packed_key_states[packed_vae_token_indexes]
)
packed_cos, packed_sin = packed_query_position_embeddings
packed_query_states, packed_key_states = apply_rotary_pos_emb(
packed_query_states,
packed_key_states,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
packed_query_states = packed_query_states.to(torch.bfloat16)
packed_key_states = packed_key_states.to(torch.bfloat16)
packed_value_states = packed_value_states.to(torch.bfloat16)
if (
past_key_values is not None
and past_key_values.key_cache[self.layer_idx] is not None
):
past_key_states = past_key_values.key_cache[self.layer_idx]
past_value_states = past_key_values.value_cache[self.layer_idx]
seqlens = sum(query_lens) + sum(key_values_lens)
merged_key_states = past_key_states.new_zeros(
size=[seqlens, self.num_key_value_heads, self.head_dim]
)
merged_value_states = past_key_states.new_zeros(
size=[seqlens, self.num_key_value_heads, self.head_dim]
)
merged_key_states[packed_query_indexes] = packed_key_states
merged_key_states[packed_key_value_indexes] = past_key_states
merged_value_states[packed_query_indexes] = packed_value_states
merged_value_states[packed_key_value_indexes] = past_value_states
key_values_lens = key_values_lens + query_lens
else:
merged_key_states = packed_key_states
merged_value_states = packed_value_states
key_values_lens = query_lens
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(key_values_lens, dim=0), (1, 0)
)
packed_attn_output = flash_attn_varlen_func(
q=packed_query_states,
k=merged_key_states,
v=merged_value_states,
cu_seqlens_q=cu_seqlens_q.to(torch.int32),
cu_seqlens_k=cu_seqlens_k.to(torch.int32),
max_seqlen_q=max(query_lens).item(),
max_seqlen_k=max(key_values_lens).item(),
causal=is_causal,
)
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
if mode == "und":
packed_attn_output = self.o_proj(packed_attn_output)
elif mode == "gen":
packed_attn_output[packed_text_indexes] = self.o_proj(
packed_attn_output[packed_text_indexes]
)
packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(
packed_attn_output[packed_vae_token_indexes]
)
if update_past_key_values:
past_key_values.key_cache[self.layer_idx] = merged_key_states
past_key_values.value_cache[self.layer_idx] = merged_value_states
return packed_attn_output, past_key_values
class Qwen2DecoderLayer(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PackedAttention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
residual = packed_sequence
packed_sequence = self.input_layernorm(packed_sequence)
# Self Attention
packed_sequence = self.self_attn(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
)
packed_sequence = residual + packed_sequence
# Fully Connected
residual = packed_sequence
packed_sequence = self.post_attention_layernorm(packed_sequence)
packed_sequence = self.mlp(packed_sequence)
packed_sequence = residual + packed_sequence
return packed_sequence
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
) -> BaseNavitOutputWithPast:
residual = packed_query_sequence
packed_query_sequence = self.input_layernorm(packed_query_sequence)
# Self Attention
packed_query_sequence, past_key_values = self.self_attn(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
)
packed_query_sequence = residual + packed_query_sequence
# Fully Connected
residual = packed_query_sequence
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
packed_query_sequence = self.mlp(packed_query_sequence)
packed_query_sequence = residual + packed_query_sequence
return packed_query_sequence, past_key_values
class Qwen2MoTDecoderLayer(nn.Module):
def __init__(
self,
config,
layer_idx: Optional[int] = None,
attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
):
super().__init__()
self.hidden_size = config.hidden_size
self.freeze_und = config.freeze_und
self.self_attn = attn_module(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.mlp_moe_gen = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm_moe_gen = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
) -> torch.Tensor:
residual = packed_sequence
packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
packed_sequence_[packed_und_token_indexes] = self.input_layernorm(
packed_sequence[packed_und_token_indexes]
)
packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
# Self Attention
packed_sequence_ = self.self_attn(
packed_sequence=packed_sequence_,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_gen_token_indexes,
)
if self.freeze_und:
packed_sequence_[packed_und_token_indexes] = packed_sequence_[
packed_und_token_indexes
].detach()
packed_sequence = residual + packed_sequence_
# Fully Connected
residual = packed_sequence
packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
packed_sequence_[packed_und_token_indexes] = self.mlp(
self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
)
if self.freeze_und:
packed_sequence_[packed_und_token_indexes] = packed_sequence_[
packed_und_token_indexes
].detach()
packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
self.post_attention_layernorm_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
)
packed_sequence = residual + packed_sequence_
return packed_sequence
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
enable_taylorseer = getattr(self, "enable_taylorseer", False)
if enable_taylorseer and self.current["type"] == "full":
self.current["module"] = "total"
taylor_cache_init(cache_dic=self.cache_dic, current=self.current)
if not enable_taylorseer or (
enable_taylorseer and self.current["type"] == "full"
):
residual = packed_query_sequence
if mode == "und":
packed_query_sequence = self.input_layernorm(packed_query_sequence)
elif mode == "gen":
packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
packed_query_sequence_[packed_text_indexes] = self.input_layernorm(
packed_query_sequence[packed_text_indexes]
)
packed_query_sequence_[packed_vae_token_indexes] = (
self.input_layernorm_moe_gen(
packed_query_sequence[packed_vae_token_indexes]
)
)
packed_query_sequence = packed_query_sequence_
# Self Attention
packed_query_sequence, past_key_values = self.self_attn(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
mode=mode,
packed_vae_token_indexes=packed_vae_token_indexes,
packed_text_indexes=packed_text_indexes,
)
packed_query_sequence = residual + packed_query_sequence
# Fully Connected
residual = packed_query_sequence
if mode == "und":
packed_query_sequence = self.post_attention_layernorm(
packed_query_sequence
)
packed_query_sequence = self.mlp(packed_query_sequence)
elif mode == "gen":
packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
packed_vae_query_sequence = packed_query_sequence[
packed_vae_token_indexes
]
packed_text_query_sequence = self.post_attention_layernorm(
packed_text_query_sequence
).to(torch.bfloat16)
packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(
packed_vae_query_sequence
).to(torch.bfloat16)
packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(
torch.bfloat16
)
packed_query_sequence_[packed_text_indexes] = self.mlp(
packed_text_query_sequence
)
packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(
packed_vae_query_sequence
)
packed_query_sequence = packed_query_sequence_
packed_query_sequence = residual + packed_query_sequence
if enable_taylorseer:
if self.current["type"] == "full":
derivative_approximation(
cache_dic=self.cache_dic,
current=self.current,
feature=packed_query_sequence,
)
elif self.current["type"] == "Taylor":
self.current["module"] = "total"
packed_query_sequence = taylor_formula(
cache_dic=self.cache_dic, current=self.current
)
return packed_query_sequence, past_key_values
class Qwen2MoEDecoderLayer(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PackedAttention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.mlp_moe_gen = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
) -> torch.Tensor:
residual = packed_sequence
packed_sequence = self.input_layernorm(packed_sequence)
# Self Attention
packed_sequence = self.self_attn(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
)
packed_sequence = residual + packed_sequence
# Fully Connected
residual = packed_sequence
packed_sequence = self.post_attention_layernorm(packed_sequence)
packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
packed_sequence_gen = self.mlp_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen
packed_sequence = residual + packed_sequence_new
return packed_sequence
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
residual = packed_query_sequence
packed_query_sequence = self.input_layernorm(packed_query_sequence)
# Self Attention
packed_query_sequence, past_key_values = self.self_attn(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
)
packed_query_sequence = residual + packed_query_sequence
# Fully Connected
residual = packed_query_sequence
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
if mode == "und":
packed_query_sequence = self.mlp(packed_query_sequence)
elif mode == "gen":
packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(
torch.bfloat16
)
packed_query_sequence_[packed_text_indexes] = self.mlp(
packed_query_sequence[packed_text_indexes]
)
packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(
packed_query_sequence[packed_vae_token_indexes]
)
packed_query_sequence = packed_query_sequence_
packed_query_sequence = residual + packed_query_sequence
return packed_query_sequence, past_key_values
Decoder_layer_dict = {
"Qwen2DecoderLayer": Qwen2DecoderLayer,
"Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
"Qwen2MoTDecoderLayer": partial(
Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT
),
}
class Qwen2Model(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.use_moe = "Mo" in config.layer_module
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
layer_module = Decoder_layer_dict[config.layer_module]
self.layers = nn.ModuleList(
[
layer_module(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if self.use_moe:
self.norm_moe_gen = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
# Initialize weights and apply final processing
self.post_init()
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_ids: torch.Tensor,
packed_und_token_indexes: Optional[torch.LongTensor] = None,
packed_gen_token_indexes: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if self.config.freeze_und:
packed_sequence[packed_und_token_indexes] = packed_sequence[
packed_und_token_indexes
].detach()
# create position embeddings to be shared across the decoder layers
cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
cos = cos.squeeze(0)
sin = sin.squeeze(0)
packed_position_embeddings = (cos, sin)
extra_inputs = {}
if self.use_moe:
assert packed_und_token_indexes is not None
if packed_gen_token_indexes is None:
packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
extra_inputs.update(
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_gen_token_indexes,
)
for decoder_layer in self.layers:
packed_sequence = decoder_layer(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
**extra_inputs,
)
if self.use_moe:
packed_sequence_ = torch.zeros_like(packed_sequence)
packed_sequence_[packed_und_token_indexes] = self.norm(
packed_sequence[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_sequence_[packed_und_token_indexes] = packed_sequence_[
packed_und_token_indexes
].detach()
packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
return packed_sequence_
else:
return self.norm(packed_sequence)
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_ids: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
enable_taylorseer = getattr(self, "enable_taylorseer", False)
if enable_taylorseer:
cal_type(self.cache_dic, self.current)
self.current["stream"] = "layers_stream"
# create position embeddings to be shared across the decoder layers
cos, sin = self.rotary_emb(
packed_query_sequence, packed_query_position_ids.unsqueeze(0)
)
cos = cos.squeeze(0)
sin = sin.squeeze(0)
packed_query_position_embeddings = (cos, sin)
extra_inputs = {}
if self.use_moe:
extra_inputs.update(mode=mode)
if mode == "gen":
assert packed_vae_token_indexes is not None
assert packed_text_indexes is not None
extra_inputs.update(
packed_vae_token_indexes=packed_vae_token_indexes,
packed_text_indexes=packed_text_indexes,
)
for layer_idx, decoder_layer in enumerate(self.layers):
if enable_taylorseer:
decoder_layer.current = self.current
decoder_layer.cache_dic = self.cache_dic
decoder_layer.enable_taylorseer = True
self.current["layer"] = layer_idx
packed_query_sequence, past_key_values = decoder_layer(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
**extra_inputs,
)
if self.use_moe:
if mode == "und":
packed_query_sequence = self.norm(packed_query_sequence)
elif mode == "gen":
packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
packed_query_sequence_[packed_text_indexes] = self.norm(
packed_query_sequence[packed_text_indexes]
)
packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(
packed_query_sequence[packed_vae_token_indexes]
)
packed_query_sequence = packed_query_sequence_
else:
packed_query_sequence = self.norm(packed_query_sequence)
if enable_taylorseer:
self.current["step"] += 1
return BaseNavitOutputWithPast(
packed_query_sequence=packed_query_sequence,
past_key_values=past_key_values,
)
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = Qwen2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def init_moe(self):
for name, param in self.named_parameters():
if "moe_gen" in name:
original_name = name.replace("_moe_gen", "")
param.data.copy_(self.state_dict()[original_name].data)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_ids: torch.Tensor,
packed_und_token_indexes: Optional[torch.LongTensor] = None,
packed_gen_token_indexes: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
outputs = self.model(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
packed_position_ids=packed_position_ids,
attention_mask=attention_mask,
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_gen_token_indexes,
)
return outputs
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_ids: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
outputs = self.model(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_ids=packed_query_position_ids,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
mode=mode,
packed_vae_token_indexes=packed_vae_token_indexes,
packed_text_indexes=packed_text_indexes,
)
return outputs
# Copyright (c) 2024 The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
import torch
from flash_attn import flash_attn_varlen_func
from torch import nn
from transformers.activations import ACT2FN
from ..siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
from ..siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
class SiglipVisionConfig(_SiglipVisionConfig):
r"""
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input images.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 16):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
Example:
```python
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
>>> configuration = SiglipVisionConfig()
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
>>> model = SiglipVisionModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
image_size=224,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
rope=True,
**kwargs,
):
super().__init__(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_channels=num_channels,
image_size=image_size,
patch_size=patch_size,
hidden_act=hidden_act,
layer_norm_eps=layer_norm_eps,
attention_dropout=attention_dropout,
**kwargs,
)
self.rope = rope
class RotaryEmbedding2D(torch.nn.Module):
def __init__(self, dim, max_h, max_w, base=10000):
super().__init__()
freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
inv_freq = 1.0 / (base**freq)
grid_h = torch.arange(0, max_h)
grid_h = grid_h.to(inv_freq.dtype)
grid_h = grid_h[:, None].repeat(1, max_w)
grid_w = torch.arange(0, max_w)
grid_w = grid_w.to(inv_freq.dtype)
grid_w = grid_w[None, :].repeat(max_h, 1)
cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)
self.register_buffer("cos_h", cos_h)
self.register_buffer("sin_h", sin_h)
self.register_buffer("cos_w", cos_w)
self.register_buffer("sin_w", sin_w)
def _forward_one_side(self, grid, inv_freq):
freqs = grid[..., None] * inv_freq[None, None, :]
emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
return emb.cos(), emb.sin()
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
# unsqueeze due to the head dimension
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
if not config.rope:
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
def convert_conv2d_to_linear(self, config, meta=False):
if meta:
linear_patch_embedding = nn.Linear(
config.num_channels * self.patch_size**2,
self.embed_dim,
bias=True,
device="meta",
)
else:
linear_patch_embedding = nn.Linear(
config.num_channels * self.patch_size**2, self.embed_dim, bias=True
)
W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
self.embed_dim, config.num_channels * self.patch_size**2
)
linear_patch_embedding.weight.data = W
linear_patch_embedding.bias.data = self.patch_embedding.bias.data
del self.patch_embedding
self.patch_embedding = linear_patch_embedding
def forward(
self,
packed_pixel_values: torch.FloatTensor,
packed_flattened_position_ids: torch.LongTensor,
) -> torch.Tensor:
patch_embeds = self.patch_embedding(packed_pixel_values)
if not self.config.rope:
embeddings = patch_embeds + self.position_embedding(
packed_flattened_position_ids
)
else:
embeddings = patch_embeds
return embeddings
class SiglipFlashAttention2(SiglipAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: torch.Tensor = None,
sin_h: torch.Tensor = None,
cos_w: torch.Tensor = None,
sin_w: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
total_q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)
if self.config.rope:
qh, qw = (
query_states[:, :, : self.head_dim // 2],
query_states[:, :, self.head_dim // 2 :],
)
kh, kw = (
key_states[:, :, : self.head_dim // 2],
key_states[:, :, self.head_dim // 2 :],
)
qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
query_states = torch.cat([qh, qw], dim=-1)
key_states = torch.cat([kh, kw], dim=-1)
attn_output = flash_attn_varlen_func(
query_states.to(torch.bfloat16),
key_states.to(torch.bfloat16),
value_states.to(torch.bfloat16),
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=False,
)
attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
return attn_output
class SiglipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipFlashAttention2(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: torch.Tensor = None,
sin_h: torch.Tensor = None,
cos_w: torch.Tensor = None,
sin_w: torch.Tensor = None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
cos_h=cos_h,
sin_h=sin_h,
cos_w=cos_w,
sin_w=sin_w,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
def forward(
self,
inputs_embeds: torch.Tensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
cos_h: torch.Tensor = None,
sin_h: torch.Tensor = None,
cos_w: torch.Tensor = None,
sin_w: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
cu_seqlens,
max_seqlen,
cos_h=cos_h,
sin_h=sin_h,
cos_w=cos_w,
sin_w=sin_w,
)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
if config.rope:
max_size = config.image_size // config.patch_size
dim_head = config.hidden_size // config.num_attention_heads
self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(
self,
packed_pixel_values: torch.Tensor,
packed_flattened_position_ids: torch.LongTensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
) -> torch.Tensor:
hidden_states = self.embeddings(
packed_pixel_values=packed_pixel_values,
packed_flattened_position_ids=packed_flattened_position_ids,
)
extra_inputs = {}
if self.config.rope:
extra_inputs.update(
cos_h=self.rope.cos_h[packed_flattened_position_ids],
sin_h=self.rope.sin_h[packed_flattened_position_ids],
cos_w=self.rope.cos_w[packed_flattened_position_ids],
sin_w=self.rope.sin_w[packed_flattened_position_ids],
)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
**extra_inputs,
)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
class SiglipVisionModel(SiglipPreTrainedModel):
config_class = SiglipVisionConfig
main_input_name = "packed_pixel_values"
def __init__(self, config: SiglipVisionConfig):
super().__init__(config)
self.vision_model = SiglipVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def forward(
self,
packed_pixel_values: torch.Tensor,
packed_flattened_position_ids: torch.LongTensor,
cu_seqlens: torch.IntTensor,
max_seqlen: int,
) -> torch.Tensor:
return self.vision_model(
packed_pixel_values=packed_pixel_values,
packed_flattened_position_ids=packed_flattened_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
"""
Utility for TaylorSeer
"""
# Adapted from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/taylorseer_utils/__init__.py
import math
from typing import Dict
import torch
def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
"""
Compute derivative approximation.
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
difference_distance = (
current["activated_steps"][-1] - current["activated_steps"][-2]
)
updated_taylor_factors = {}
updated_taylor_factors[0] = feature
for i in range(cache_dic["max_order"]):
if (
cache_dic["cache"][-1][current["stream"]][current["layer"]][
current["module"]
].get(i, None)
is not None
) and (current["step"] > cache_dic["first_enhance"] - 2):
updated_taylor_factors[i + 1] = (
updated_taylor_factors[i]
- cache_dic["cache"][-1][current["stream"]][current["layer"]][
current["module"]
][i]
) / difference_distance
else:
break
cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = (
updated_taylor_factors
)
def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
"""
Compute Taylor expansion error.
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
x = current["step"] - current["activated_steps"][-1]
# x = current['t'] - current['activated_times'][-1]
output = 0
for i in range(
len(
cache_dic["cache"][-1][current["stream"]][current["layer"]][
current["module"]
]
)
):
output += (
(1 / math.factorial(i))
* cache_dic["cache"][-1][current["stream"]][current["layer"]][
current["module"]
][i]
* (x**i)
)
return output
def taylor_cache_init(cache_dic: Dict, current: Dict):
"""
Initialize Taylor cache and allocate storage for different-order derivatives in the Taylor cache.
:param cache_dic: Cache dictionary
:param current: Information of the current step
"""
if (current["step"] == 0) and (cache_dic["taylor_cache"]):
cache_dic["cache"][-1][current["stream"]][current["layer"]][
current["module"]
] = {}
# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/force_scheduler.py
def force_scheduler(cache_dic, current):
if cache_dic["fresh_ratio"] == 0:
# FORA
linear_step_weight = 0.0
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(
1
- linear_step_weight
+ 2 * linear_step_weight * current["step"] / current["num_steps"]
)
threshold = torch.round(cache_dic["fresh_threshold"] / step_factor)
# no force constrain for sensitive steps, cause the performance is good enough.
# you may have a try.
cache_dic["cal_threshold"] = threshold
# return threshold
# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cal_type.py
def cal_type(cache_dic, current):
"""
Determine calculation type for this step
"""
if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]):
# FORA:Uniform
first_step = current["step"] == 0
else:
# ToCa: First enhanced
first_step = current["step"] < cache_dic["first_enhance"]
if not first_step:
fresh_interval = cache_dic["cal_threshold"]
else:
fresh_interval = cache_dic["fresh_threshold"]
if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1):
current["type"] = "full"
cache_dic["cache_counter"] = 0
current["activated_steps"].append(current["step"])
force_scheduler(cache_dic, current)
elif cache_dic["taylor_cache"]:
cache_dic["cache_counter"] += 1
current["type"] = "Taylor"
elif (
cache_dic["cache_counter"] % 2 == 1
): # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
cache_dic["cache_counter"] += 1
current["type"] = "ToCa"
# 'cache_noise' 'ToCa' 'FORA'
elif cache_dic["Delta-DiT"]:
cache_dic["cache_counter"] += 1
current["type"] = "Delta-Cache"
else:
cache_dic["cache_counter"] += 1
current["type"] = "ToCa"
# Modified from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cache_init.py
def cache_init(self, num_steps: int):
"""
Initialization for cache.
"""
cache_dic = {}
cache = {}
cache_index = {}
cache[-1] = {}
cache_index[-1] = {}
cache_index["layer_index"] = {}
cache[-1]["layers_stream"] = {}
cache_dic["cache_counter"] = 0
for j in range(len(self.language_model.model.layers)):
cache[-1]["layers_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic["Delta-DiT"] = False
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 3
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["taylor_cache"] = True
cache_dic["max_order"] = 6
cache_dic["first_enhance"] = 5
current = {}
current["activated_steps"] = [0]
current["step"] = 0
current["num_steps"] = num_steps
return cache_dic, current
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment