Commit 2a934cec authored by raojy's avatar raojy
Browse files

first

parent 4b618aa3
{"id": "image1", "image": "./examples/vqa/data/images/image1.jpg", "question": "Describe the image content."}
{"id": "image2", "image": "./examples/vqa/data/images/image2.jpg", "question": "How many dogs are there in this image?"}
{"id": "image3", "image": "./examples/vqa/data/images/image3.jpg", "question": "Which pill do you choose?"}
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import torch
import sensenova_u1
from sensenova_u1.models.neo_unify.utils import load_image_native
from sensenova_u1.utils import (
DEFAULT_IMAGE_PATCH_SIZE,
DEFAULT_VRAM_MODE,
InferenceProfiler,
add_offload_args,
best_available_device,
infer_input_device,
load_model_and_tokenizer,
make_offload_ctx,
vram_mode_to_prefetch_count,
)
class SenseNovaU1VQA:
    """Inference front-end for visual understanding / VQA with SenseNova-U1.

    The model and tokenizer are loaded once in the constructor; :meth:`answer`
    can then be called repeatedly, optionally threading a chat ``history``.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        gguf_checkpoint: str | None = None,
        device_map: str | None = None,
        max_memory: str | None = None,
        vram_mode: str = DEFAULT_VRAM_MODE,
    ) -> None:
        """Load the checkpoint at ``model_path`` and record the input device.

        ``vram_mode`` is translated into a prefetch count; a positive count
        makes the loader prepare the model for offloading.
        """
        self.vram_mode = vram_mode
        self.prefetch_count = vram_mode_to_prefetch_count(vram_mode)

        loaded = load_model_and_tokenizer(
            model_path,
            dtype=dtype,
            device=device,
            gguf_checkpoint=gguf_checkpoint,
            for_offload=self.prefetch_count > 0,
            device_map=device_map,
            max_memory=max_memory,
        )
        self.model, self.tokenizer = loaded

        # With a device map the inputs must go wherever the model expects
        # them; otherwise the requested device is used as-is.
        if device_map:
            self.device = str(infer_input_device(self.model, fallback=device))
        else:
            self.device = device

    @torch.inference_mode()
    def answer(
        self,
        image,
        question: str,
        history: list | None = None,
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int | None = None,
        repetition_penalty: float | None = None,
    ) -> tuple[str, list]:
        """Ask ``question`` about ``image`` and return ``(response, updated_history)``.

        Sampling knobs (``temperature``, ``top_p``, ``top_k``) are forwarded
        only when ``do_sample`` is true; ``repetition_penalty`` is forwarded
        whenever it is not ``None``.
        """
        pixels, grid = load_image_native(image)
        pixels = pixels.to(self.device, dtype=self.model.dtype)
        grid = grid.to(self.device)

        gen_cfg = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
        if do_sample:
            gen_cfg["temperature"] = temperature
            gen_cfg["top_p"] = top_p
            if top_k is not None:
                gen_cfg["top_k"] = top_k
        if repetition_penalty is not None:
            gen_cfg["repetition_penalty"] = repetition_penalty

        with make_offload_ctx(self.model, self.prefetch_count, self.device) as runner:
            return runner.chat(
                self.tokenizer,
                pixels,
                question,
                gen_cfg,
                history=history,
                return_history=True,
                grid_hw=grid,
            )
def parse_args() -> argparse.Namespace:
    """Build the VQA command-line parser and parse ``sys.argv``.

    Exactly one of ``--image`` / ``--jsonl`` must be supplied; offload-related
    flags are contributed by ``add_offload_args``.
    """
    parser = argparse.ArgumentParser(description="Visual understanding / VQA inference for SenseNova-U1.")
    parser.add_argument(
        "--model_path",
        required=True,
        help="HuggingFace Hub id (e.g. sensenova/SenseNova-U1-8B-MoT) or a local path.",
    )

    # Input source: a single image or a JSONL batch, never both.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--image", help="Path to a single image file.")
    jsonl_help = (
        'JSONL file, one sample per line. Required fields: {"image": ..., "question": ...}. '
        'Optional: {"id": ...}.'
    )
    source.add_argument("--jsonl", help=jsonl_help)

    parser.add_argument("--question", help="Question to ask about the image (used with --image).")
    parser.add_argument("--output", default=None, help="Output file for single-image result (default: stdout).")
    parser.add_argument("--output_dir", default="outputs", help="Output directory when using --jsonl.")

    # Generation controls.
    parser.add_argument("--max_new_tokens", type=int, default=1024)
    parser.add_argument("--do_sample", action="store_true", help="Enable sampling (default: greedy).")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--top_k", type=int, default=None, help="Top-k sampling (default: None).")
    parser.add_argument(
        "--repetition_penalty", type=float, default=None, help="Repetition penalty (default: None)."
    )

    # Runtime placement and precision.
    parser.add_argument(
        "--device",
        default=str(best_available_device()),
        help="Compute device, e.g. 'cuda', 'cuda:0', 'xpu', 'xpu:0', 'cpu'. Defaults to the best available accelerator.",
    )
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
    )
    add_offload_args(parser)

    gguf_help = (
        "Optional path to a .gguf quantized checkpoint. When set, the dequantizing "
        "diffusers GGUF Linear layer is used instead of safetensors weights. "
        "Requires the [gguf] extra (gguf>=0.10.0, diffusers>=0.30.0)."
    )
    parser.add_argument("--gguf_checkpoint", default=None, help=gguf_help)

    attn_help = "Attention kernel used by the Qwen3 layers. 'auto' picks flash-attn when importable and falls back to SDPA."
    parser.add_argument(
        "--attn_backend",
        default="auto",
        choices=["auto", "flash", "sdpa"],
        help=attn_help,
    )

    profile_help = (
        "Print timing and CUDA memory stats: model load time, average "
        "per-image inference time, peak GPU memory, and the same time "
        f"normalized per image token (patch size = {DEFAULT_IMAGE_PATCH_SIZE})."
    )
    parser.add_argument("--profile", action="store_true", help=profile_help)

    return parser.parse_args()
def main() -> None:
    """Run VQA inference from the command line.

    Two modes are supported:

    * ``--image`` + ``--question``: answer one question and either print the
      response or write it to ``--output``.
    * ``--jsonl``: answer every sample in the file and write all results to
      ``<output_dir>/answers.jsonl``.

    All text files are read and written as UTF-8 so that non-ASCII questions
    and answers (the results are dumped with ``ensure_ascii=False``) do not
    depend on the process locale.
    """
    args = parse_args()
    if args.image is not None and args.question is None:
        print("[error] --question is required when using --image", file=sys.stderr)
        sys.exit(1)

    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
    sensenova_u1.set_attn_backend(args.attn_backend)
    print(f"[attn] backend={args.attn_backend!r} (effective={sensenova_u1.effective_attn_backend()!r})")

    profiler = InferenceProfiler(
        enabled=args.profile,
        device=args.device,
        config={
            "vram_mode": args.vram_mode,
            "attn_backend": sensenova_u1.effective_attn_backend(),
            "dtype": args.dtype,
            "gguf": args.gguf_checkpoint,
        },
    )
    with profiler.time_load():
        engine = SenseNovaU1VQA(
            args.model_path,
            device=args.device,
            dtype=dtype,
            gguf_checkpoint=args.gguf_checkpoint,
            device_map=args.device_map,
            max_memory=args.max_memory,
            vram_mode=args.vram_mode,
        )

    # Generation settings are identical for both modes; build them once.
    generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "repetition_penalty": args.repetition_penalty,
    }

    if args.image is not None:
        # Single-image mode.
        # NOTE(review): width/height are passed as 1x1 placeholders, not the
        # real image size — confirm how the profiler normalizes per-token time.
        img_path = Path(args.image)
        with profiler.time_generate(width=1, height=1, batch=1):
            response, _ = engine.answer(img_path, args.question, **generation_kwargs)
        if args.output:
            out = Path(args.output)
            out.parent.mkdir(parents=True, exist_ok=True)
            out.write_text(response, encoding="utf-8")
            print(f"[saved] {out}")
        else:
            print(response)
        profiler.report()
        return

    # Batch JSONL mode.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(args.jsonl, encoding="utf-8") as f:
        # Blank lines are tolerated and skipped.
        samples = [json.loads(line) for line in f if line.strip()]

    try:
        from tqdm import tqdm
    except ImportError:
        # Progress bar is optional; fall back to plain iteration.
        def tqdm(x, **_kw):  # type: ignore[no-redef]
            return x

    results = []
    for sample in tqdm(samples, desc="VQA"):
        img_path = Path(sample["image"])
        question = sample["question"]
        with profiler.time_generate(width=1, height=1, batch=1):
            response, _ = engine.answer(img_path, question, **generation_kwargs)
        result = {"id": sample.get("id", ""), "image": str(img_path), "question": question, "answer": response}
        results.append(result)
        # Console preview is truncated to 80 characters.
        print(f"[{result['id'] or '?'}] {response[:80]}{'...' if len(response) > 80 else ''}")

    out_file = out_dir / "answers.jsonl"
    with open(out_file, "w", encoding="utf-8") as f:
        for r in results:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    print(f"[saved] {out_file}")
    profiler.report()
if __name__ == "__main__":
main()
[project]
name = "sensenova-u1"
version = "0.1.0"
description = "SenseNova-U1: Unifying Multimodal Understanding and Generation with NEO-Unify Architecture."
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.11,<3.12"
authors = [
{ name = "OpenSenseNova" },
]
keywords = [
# NOTE: the command lines and path below had been pasted into this array as bare
# text, which makes the file invalid TOML. They are preserved as comments.
#
# python examples/vqa/inference.py --model_path /public/home/raojy/project/model_code/novauuu --image examples/vqa/data/images/menu.jpg --question "My friend and I are dining together tonight. Looking at this menu, can you recommend a good combination of dishes for 2 people? We want a balanced meal — a mix of mains and maybe a starter or dessert. Budget-conscious but want to try the highlights." --output outputs/answer.txt --max_new_tokens 8192 --do_sample --temperature 0.6 --top_p 0.95 --top_k 20 --repetition_penalty 1.05 --profile
# /public/home/raojy/project/model_code/novauuu
# python examples/t2i/inference.py --model_path /public/home/raojy/project/model_code/novauuu --prompt "这张信息图的标题是“SenseNova-U1”,采用现代极简科技矩阵风格。整体布局为水平三列网格结构,背景是带有极浅银灰色细密点阵的哑光纯白高级纸张纹理,画面长宽比为16:9。\n\n排版采用严谨的视觉层级:主标题使用粗体无衬线黑体字,正文使用清晰的现代等宽字体。配色方案极其克制,以纯白色为底,深炭黑为主视觉文字和边框,浅石板灰用于背景色块和次要信息区分,图标采用精致的银灰色线框绘制。\n\n在画面正上方居中位置,使用醒目的深炭黑粗体字排布着大标题“SenseNova-U1”。标题正下方是浅石板灰色的等宽字体副标题“新一代端到端统一多模态大模型家族”。\n\n画面主体分为左、中、右三个相等的垂直信息区块,区块之间通过充足的负空间进行物理隔离。\n\n左侧区块的主题是概述。顶部有一个银灰色线框绘制的、由放大镜和齿轮交织的图标,旁边是粗体小标题“Overview”。该区块内从上到下垂直排列着三个要点:第一个要点旁边是一个代表文档与照片重叠的极简图标,紧跟着文字“多模态模型家族,统一文本/图像理解和生成”。向下是由两个相连的同心圆组成的架构图标,配有文字“基于NEO-Unify架构(端到端统一理解和生成)”。最下方是一个带有斜线划掉的眼睛和漏斗形状的图标,明确指示文本“无需视觉编码器(VE)和变分自编码器(VAE)”。\n\n中间区块展示模型矩阵。顶部是一个包含两个分支节点的树状网络图标,旁边是粗体小标题“两个模型规格”。区块内分为上下两个包裹在浅石板灰色极细边框内的卡片。上方的卡片内画着一个代表高密度的实心几何立方体图标,大字标注“SenseNova-U1-8B-MoT”,下方是等宽字体说明“8B MoT 密集主干模型”。下方的卡片内画着一个带有闪电符号的网状发光大脑图标,大字标注“SenseNova-U1-A3B-MoT”,下方是等宽字体说明“A3B MoT 混合专家(MoE)主干模型”。在这两个独立卡片的正下方,左侧放置一个笑脸轮廓图标搭配文字“将在HF等平台公开”,右侧放置一个带有折角的书面报告图标搭配文字“将发布技术报告”。\n\n右侧区块呈现核心优势。顶部是一个代表巅峰的上升阶梯折线图图标,旁边是粗体小标题“Highlights”。该区块内部垂直分布着四个带有浅石板灰底色的长方形色块,每个色块内部左侧对应一个具体的图标,右侧为文字。第一个色块内是一个无缝相连的莫比乌斯环图标,配文“原生统一架构,无VE和VAE”。第二个色块内是一个顶端带有星星的奖杯图标,配文“单一统一模型在理解和生成任务上均达到SOTA性能”。第三个色块内是代表文本行与拍立得照片交替穿插的图标,配文“强大的原生交错推理能力(模型原生生成图像进行推理)”。最后一个色块内是一个被切分出一小块的硬币与详细饼状图结合的图标,配文“能生成复杂信息图表,性价比出色”。" --width 2720 --height 1536 --cfg_scale 4.0 --cfg_norm none --timestep_shift 3.0 --num_steps 50 --output output.png --profile
"multimodal",
"vision-language-model",
"unified-model",
"understanding-and-generation",
"native-vlm",
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.11",
"Operating System :: POSIX :: Linux",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
# Core dependencies.
#
# Versions here are pinned to match the internal reference environment
# (`miniconda3/envs/lmms_engine`, Python 3.11, CUDA 12.8)
# so that checkpoints trained there can be loaded and run bit-exactly. When
# bumping any of these, please re-verify inference on a reference checkpoint.
dependencies = [
"transformers==4.57.1",
"tokenizers==0.22.1",
"accelerate==1.10.1",
"huggingface-hub==0.36.2",
"safetensors==0.6.2",
"sentencepiece==0.2.1",
# NumPy is API-stable across 1.24..2.x for the call sites in sensenova_u1
# (np.arange/concatenate/cos/sin/einsum/meshgrid/stack/zeros and dtypes
# np.float32/64 — none of the numpy 2 removals like np.float_, np.cast,
# np.product are touched). Use a range so downstream consumers (notably
# ComfyUI environments shipping numpy 1.26.x) don't hit a dependency
# resolver conflict against an exact pin.
"numpy>=1.24,<3",
"pillow==12.0.0",
"tqdm==4.67.1",
"packaging==25.0",
# HTTP client used by `sensenova_u1.prompt_enhance.adapters`. Version is left
# open because only the high-level AsyncClient is used and the API has been
# stable since 0.24.
"httpx>=0.27,<1",
"pre-commit>=4.5.1",
]
[project.optional-dependencies]
# Optional flash-attention acceleration. When absent, the model transparently
# falls back to torch SDPA; see src/sensenova_u1/models/neo_unify/modeling_qwen3.py.
#
# The reference env uses a cp311 / torch 2.8 wheel (`flash_attn 2.8.3`). PyPI
# does not host CUDA-specific wheels, so users will typically install this
# manually from a local .whl that matches their torch + Python combination, e.g.:
#
# uv pip install /path/to/flash_attn-2.8.3+cu12torch28cxx11abitrue-cp311-cp311-*.whl
flash = [
"flash-attn>=2.8,<3",
]
gguf = [
"gguf>=0.10.0",
"diffusers>=0.30.0",
]
dev = [
"ruff",
"pytest",
"pre-commit",
]
[project.urls]
Homepage = "https://github.com/OpenSenseNova/SenseNova-U1"
Repository = "https://github.com/OpenSenseNova/SenseNova-U1"
[project.scripts]
sensenova-u1 = "sensenova_u1:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/sensenova_u1"]
# -----------------------------------------------------------------------------
# uv configuration: pin torch / torchvision to the PyTorch cu128 index.
#
# This makes `uv sync` reproducibly install the CUDA 12.8 build of
# torch==2.8.0 (+cu128), matching the reference env exactly. If your NVIDIA
# driver is too old for CUDA 12.8, swap the URL below to e.g.
# `https://download.pytorch.org/whl/cu126` and adjust torch/torchvision
# versions accordingly.
# -----------------------------------------------------------------------------
[tool.uv.sources]
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.ruff]
line-length = 120
target-version = "py311"
extend-exclude = [
"src/sensenova_u1/models/neo_unify",
# Vendored verbatim from SenseNova-Skills; keep diff-clean for easy uprev.
"src/sensenova_u1/prompt_enhance/adapters",
# Third-party submodules — keep diff-clean from upstream.
"evaluation/easi/EASI",
"evaluation/easi/lightllm-stack/LightLLM",
]
[tool.ruff.lint]
select = ["I"]
from __future__ import annotations
import argparse
import torch
from sensenova_u1.utils import ModelParamInspector, build_rules, format_bytes, format_param_count
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the parameter-count report.

    Returns a namespace with ``model_path``, ``dtype``, ``custom_groups_json``
    and ``show_groups``.
    """
    description = (
        "Count actual model parameters and split by functional groups. "
        "Default groups are tuned for SenseNova-U1-8B-MoT."
    )
    cli = argparse.ArgumentParser(description=description)

    cli.add_argument(
        "--model_path",
        required=True,
        help="Local checkpoint path or HuggingFace model id.",
    )

    dtype_help = (
        "Load dtype. bfloat16 by default to align with inference scripts. "
        "Note: dtype only affects load-time memory; param counts are dtype-independent."
    )
    cli.add_argument(
        "--dtype",
        default="bfloat16",
        choices=("float32", "float16", "bfloat16"),
        help=dtype_help,
    )

    cli.add_argument(
        "--custom_groups_json",
        default=None,
        help='Optional JSON file to override grouping rules. Format: {"group_name": ["prefix1", "prefix2"]}.',
    )

    show_help = (
        "Comma-separated group names whose member parameters will be listed in detail. "
        "Use 'all' for every group, or empty string to disable. Default: shared."
    )
    cli.add_argument("--show_groups", default="shared", help=show_help)

    return cli.parse_args()
def main() -> None:
    """Load the model, count parameters per group and print the report tables."""
    args = parse_args()

    # The three accepted CLI names map one-to-one onto torch dtypes.
    torch_dtypes = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    inspector = ModelParamInspector(
        model_path=args.model_path,
        dtype=torch_dtypes[args.dtype],
    )
    result = inspector.count(build_rules(args.custom_groups_json))

    # Column layout: name | params | memory | ratio, single-space separated.
    name_w = 28
    width = name_w + 1 + 12 + 1 + 16 + 1 + 10  # = 69
    rule = "-" * width
    dtype_label = args.dtype
    memory_header = f"memory ({dtype_label})"

    print(f"Model: {result.model_path}")
    print(f"Load dtype: {dtype_label}")
    print(f"Total params: {format_param_count(result.total_params)}")
    print(f"Total memory: {format_bytes(result.total_bytes)} ({dtype_label})")
    print(rule)
    print(f"{'group':<{name_w}} {'params':>12} {memory_header:>16} {'ratio':>10}")
    print(rule)

    for group in result.groups:
        # Guard against an empty model so the ratio never divides by zero.
        if result.total_params:
            ratio = (group.params / result.total_params) * 100.0
        else:
            ratio = 0.0
        params_text = format_param_count(group.params)
        bytes_text = format_bytes(group.bytes)
        print(f"{group.name:<{name_w}} {params_text:>12} {bytes_text:>16} {ratio:>9.2f}%")

    _print_pathway_summary(result, width=width, name_w=name_w, memory_header=memory_header)
    _print_group_entries(result, args.show_groups, width=width, dtype_label=dtype_label)
def _print_pathway_summary(result, *, width: int, name_w: int, memory_header: str) -> None:
"""Forward-pathway coverage: parameters touched when running each task.
Each pathway sums its dedicated transformer with the shared text I/O.
Both embed_tokens and lm_head are exercised by both tasks: the latter is
used by t2i-reasoning during the thinking phase before image tokens are
emitted, so it is a real shared component, not understanding-only.
"""
by_name = {g.name: g for g in result.groups}
required = ("understanding_transformer", "generation_transformer", "shared")
if not all(k in by_name for k in required):
return
shared = by_name["shared"]
pathways = (
(
"understanding pathway",
by_name["understanding_transformer"].params + shared.params,
by_name["understanding_transformer"].bytes + shared.bytes,
),
(
"generation pathway",
by_name["generation_transformer"].params + shared.params,
by_name["generation_transformer"].bytes + shared.bytes,
),
)
print("-" * width)
print("Pathway breakdown (shared counted in both):")
print("-" * width)
print(f"{'pathway':<{name_w}} {'params':>12} {memory_header:>16} {'ratio':>10}")
print("-" * width)
for name, params, nbytes in pathways:
ratio = (params / result.total_params) * 100.0 if result.total_params else 0.0
print(f"{name:<{name_w}} {format_param_count(params):>12} {format_bytes(nbytes):>16} {ratio:>9.2f}%")
def _print_group_entries(result, show_groups_arg: str, *, width: int, dtype_label: str) -> None:
"""Dump member parameters for the requested groups."""
show_groups_arg = (show_groups_arg or "").strip()
if not show_groups_arg:
return
by_name = {g.name: g for g in result.groups}
if show_groups_arg.lower() == "all":
target_names = [g.name for g in result.groups]
else:
target_names = [n.strip() for n in show_groups_arg.split(",") if n.strip()]
for group_name in target_names:
group = by_name.get(group_name)
if group is None:
print()
print(f"[show_groups] group '{group_name}' not found, skipped.")
continue
print()
print("-" * width)
print(
f"Members of group '{group.name}' "
f"({len(group.entries)} params, {format_param_count(group.params)} total, "
f"{format_bytes(group.bytes)} @ {dtype_label})"
)
print("-" * width)
print(f"{'param name':<54} {'numel':>10} {'dtype':>8}")
print("-" * width)
for entry in group.entries:
print(f"{entry.name:<54} {format_param_count(entry.numel):>10} {entry.dtype:>8}")
if __name__ == "__main__":
main()
from __future__ import annotations
from importlib import metadata as _metadata
from typing import Any
from .models.neo_unify import (
NEOChatConfig,
NEOChatModel,
NEOLLMConfig,
NEOMoELLMConfig,
NEOVisionConfig,
NEOVisionModel,
effective_attn_backend,
get_attn_backend,
has_flash_attn,
set_attn_backend,
)
from .models.neo_unify import (
register as _register,
)
try:
__version__ = _metadata.version("sensenova-u1")
except _metadata.PackageNotFoundError: # pragma: no cover - editable / not installed
__version__ = "0.1.0"
__all__ = [
"__version__",
"NEOChatConfig",
"NEOLLMConfig",
"NEOMoELLMConfig",
"NEOVisionConfig",
"NEOChatModel",
"NEOVisionModel",
"check_checkpoint_compatibility",
"set_attn_backend",
"get_attn_backend",
"effective_attn_backend",
"has_flash_attn",
"main",
]
def check_checkpoint_compatibility(config_or_dict: Any) -> None:
    """Fail loudly when a checkpoint needs a newer ``sensenova_u1`` than is installed.

    A checkpoint may declare ``sensenova_u1_min_version`` in its
    ``config.json``. When that field is present and greater than the
    installed ``__version__``, a ``RuntimeError`` with upgrade instructions is
    raised. The check is silently skipped when ``packaging`` is unavailable,
    when the argument is neither a dict nor an object exposing ``to_dict()``,
    or when the field is missing/empty — so older checkpoints stay loadable.
    """
    try:
        from packaging.version import Version
    except ImportError:  # pragma: no cover
        return

    # Normalise the input to a plain mapping.
    if hasattr(config_or_dict, "to_dict"):
        settings: dict = config_or_dict.to_dict()
    elif isinstance(config_or_dict, dict):
        settings = config_or_dict
    else:
        return

    minimum = settings.get("sensenova_u1_min_version")
    if minimum and Version(__version__) < Version(str(minimum)):
        raise RuntimeError(
            f"This checkpoint requires sensenova-u1 >= {minimum}, "
            f"but the installed version is {__version__}. "
            f"Please upgrade with `uv sync` or `pip install -U sensenova-u1`."
        )
_register()
def main() -> None:
    """Print the package name together with the resolved ``__version__``."""
    banner = f"SenseNova-U1 v{__version__}"
    print(banner)
from __future__ import annotations
from . import neo_unify # noqa: F401 (re-export & register on import)
__all__ = ["neo_unify"]
from __future__ import annotations
from .configuration_neo_chat import NEOChatConfig, NEOLLMConfig, NEOMoELLMConfig
from .configuration_neo_vit import NEOVisionConfig
from .modeling_neo_chat import NEOChatModel
from .modeling_neo_vit import NEOVisionModel
from .modeling_qwen3 import (
_HAS_FLASH_ATTN as has_flash_attn,
effective_attn_backend,
get_attn_backend,
set_attn_backend,
)
from .modeling_qwen3 import Qwen3ForCausalLM
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
__all__ = [
"NEOChatConfig",
"NEOLLMConfig",
"NEOMoELLMConfig",
"NEOVisionConfig",
"NEOChatModel",
"NEOVisionModel",
"Qwen3ForCausalLM",
"Qwen3MoeForCausalLM",
"register",
"set_attn_backend",
"get_attn_backend",
"effective_attn_backend",
"has_flash_attn",
]
_REGISTERED = False
def register() -> None:
    """Hook the NEO-Unify config/model classes into ``transformers.Auto*``.

    Once this has run (importing ``sensenova_u1`` is enough), a SenseNova-U1
    checkpoint can be loaded with plain ``AutoConfig.from_pretrained`` /
    ``AutoModel.from_pretrained``. Repeated calls are no-ops.
    """
    global _REGISTERED
    if not _REGISTERED:
        # Imported lazily, only on the first registration.
        from transformers import AutoConfig, AutoModel

        config_types = (
            ("neo_vision", NEOVisionConfig),
            ("neo_chat", NEOChatConfig),
        )
        for type_name, config_cls in config_types:
            AutoConfig.register(type_name, config_cls, exist_ok=True)

        model_types = (
            (NEOVisionConfig, NEOVisionModel),
            (NEOChatConfig, NEOChatModel),
        )
        for config_cls, model_cls in model_types:
            AutoModel.register(config_cls, model_cls, exist_ok=True)

        _REGISTERED = True
import copy
from transformers import Qwen3Config, Qwen3MoeConfig
from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig
from .configuration_neo_vit import NEOVisionConfig
logger = logging.get_logger(__name__)
class NEOLLMConfig(Qwen3Config):
    """Dense Qwen3 backbone configuration for NEO-Unify.

    Adds two rope settings on top of ``Qwen3Config`` for the spatial
    (height/width) rotary axes that sit alongside the temporal one:
    ``rope_theta_hw`` and ``max_position_embeddings_hw``.
    """

    def __init__(
        self,
        rope_theta_hw=10000.0,
        max_position_embeddings_hw=10000,
        **kwargs,
    ):
        # Everything not listed above is handled by the stock Qwen3 config.
        super().__init__(**kwargs)
        self.rope_theta_hw = rope_theta_hw
        self.max_position_embeddings_hw = max_position_embeddings_hw
class NEOMoELLMConfig(Qwen3MoeConfig):
    """Qwen3-MoE backbone configuration for NEO-Unify.

    On top of ``Qwen3MoeConfig`` this carries the same ``rope_theta_hw`` /
    ``max_position_embeddings_hw`` extras as :class:`NEOLLMConfig` plus a
    second, *generation-path* MoE description.

    In the A3B unified model every decoder layer holds two parallel sparse MoE
    blocks, selected per token by the ``image_gen_indicators`` mask:

    * ``mlp``         - understanding path (``num_experts`` experts,
                        ``num_experts_per_tok`` active, expert width
                        ``moe_intermediate_size``).
    * ``mlp_mot_gen`` - image generation path (``gen_num_experts`` experts,
                        ``gen_num_experts_per_tok`` active, expert width
                        ``gen_moe_intermediate_size``).

    Any gen-path setting left as ``None`` inherits the matching
    understanding-path value, so plain single-MoE configs load unchanged.
    """

    def __init__(
        self,
        rope_theta_hw=10000.0,
        max_position_embeddings_hw=10000,
        gen_num_experts=None,
        gen_num_experts_per_tok=None,
        gen_moe_intermediate_size=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.rope_theta_hw = rope_theta_hw
        self.max_position_embeddings_hw = max_position_embeddings_hw

        # Each generation-path knob falls back to its understanding-path
        # counterpart and is always stored as an int.
        gen_settings = (
            ("gen_num_experts", gen_num_experts, "num_experts"),
            ("gen_num_experts_per_tok", gen_num_experts_per_tok, "num_experts_per_tok"),
            ("gen_moe_intermediate_size", gen_moe_intermediate_size, "moe_intermediate_size"),
        )
        for attr_name, explicit, fallback_attr in gen_settings:
            chosen = explicit if explicit is not None else getattr(self, fallback_attr)
            setattr(self, attr_name, int(chosen))

        # ``Qwen3Attention`` (used by the NEO-Unify MoE layers) looks up
        # ``config.layer_types[layer_idx]`` to pick "full_attention" vs
        # "sliding_attention". Older / vanilla ``Qwen3MoeConfig`` leaves that
        # field unset, so rebuild it the way the dense ``Qwen3Config`` does:
        # with ``use_sliding_window`` on, layers from ``max_window_layers``
        # upward use sliding attention.
        current = getattr(self, "layer_types", None)
        if current and len(current) == self.num_hidden_layers:
            return

        sliding_enabled = bool(getattr(self, "use_sliding_window", False)) and (
            getattr(self, "sliding_window", None) is not None
        )
        first_sliding_layer = int(getattr(self, "max_window_layers", 0) or 0)

        layer_types = []
        for layer_idx in range(self.num_hidden_layers):
            if sliding_enabled and layer_idx >= first_sliding_layer:
                layer_types.append("sliding_attention")
            else:
                layer_types.append("full_attention")
        self.layer_types = layer_types
def _is_moe_llm_config(llm_config) -> bool:
"""Detect whether an ``llm_config`` (dict or object) targets a MoE backbone.
Order of checks: explicit ``model_type``, ``architectures`` entry that
contains ``MoE/MoeForCausalLM``, or presence of MoE-specific keys
(``num_experts``).
"""
if isinstance(llm_config, dict):
model_type = llm_config.get("model_type", "")
archs = llm_config.get("architectures") or []
has_num_experts = "num_experts" in llm_config
else:
model_type = getattr(llm_config, "model_type", "")
archs = getattr(llm_config, "architectures", None) or []
has_num_experts = hasattr(llm_config, "num_experts")
if isinstance(model_type, str) and "moe" in model_type.lower():
return True
for arch in archs:
arch_str = str(arch)
if "Moe" in arch_str or "MoE" in arch_str:
return True
return bool(has_num_experts) and getattr(llm_config, "num_experts", 0) and int(getattr(llm_config, "num_experts", 0)) > 1
def _build_llm_config(llm_config):
"""Instantiate the right LLM config object from a dict or pre-built config."""
if isinstance(llm_config, dict):
if _is_moe_llm_config(llm_config):
return NEOMoELLMConfig(**llm_config)
return NEOLLMConfig(**llm_config)
return llm_config
class NEOChatConfig(PretrainedConfig):
    """Composite configuration bundling a vision config and an LLM config.

    ``vision_config`` and ``llm_config`` may each be given as a dict or as an
    already-built config object. Dicts are converted: the vision dict into a
    :class:`NEOVisionConfig`, the LLM dict into either :class:`NEOLLMConfig`
    or :class:`NEOMoELLMConfig` (chosen by ``_build_llm_config``).
    ``tie_word_embeddings`` is mirrored from the LLM config.
    """

    model_type = 'neo_chat'
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        llm_config=None,
        use_backbone_lora=0,
        use_llm_lora=0,
        downsample_ratio=0.5,
        template=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {'architectures': ['NEOVisionModel']}
            logger.info('vision_config is None. Initializing the NEOVisionConfig with default values.')

        if llm_config is None:
            llm_config = {'architectures': ['Qwen3ForCausalLM']}
            # The default built here is a Qwen3-based NEOLLMConfig; the
            # message previously named `LlamaConfig`, which was misleading.
            logger.info('llm_config is None. Initializing the NEOLLMConfig with default values (`Qwen3ForCausalLM`).')

        assert 'architectures' in llm_config, "Should specify architecture in llm_config"

        if isinstance(vision_config, dict):
            self.vision_config = NEOVisionConfig(**vision_config)
        else:
            self.vision_config = vision_config

        # Dense vs MoE is decided from the dict contents; objects pass through.
        self.llm_config = _build_llm_config(llm_config)

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.tie_word_embeddings = self.llm_config.tie_word_embeddings

    @property
    def is_moe_llm(self) -> bool:
        """Whether the LLM config is a :class:`NEOMoELLMConfig` (MoE) rather than dense."""
        return isinstance(self.llm_config, NEOMoELLMConfig)

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        The nested ``vision_config`` / ``llm_config`` objects are replaced by
        their own ``to_dict()`` output and ``model_type`` is taken from the class.

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output['vision_config'] = self.vision_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        output['model_type'] = self.__class__.model_type
        output['use_backbone_lora'] = self.use_backbone_lora
        output['use_llm_lora'] = self.use_llm_lora
        output['downsample_ratio'] = self.downsample_ratio
        output['template'] = self.template
        return output
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class NEOVisionConfig(PretrainedConfig):
    """Configuration for the NEO vision component (``model_type = 'neo_vision'``).

    Every constructor argument is stored under an attribute of the same name;
    remaining keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = 'neo_vision'

    def __init__(
        self,
        num_channels=3,
        patch_size=16,
        hidden_size=1024,
        llm_hidden_size=2048,
        downsample_ratio=0.5,
        rope_theta_vision=10000.0,
        max_position_embeddings_vision=10000,
        min_pixels=65536,
        max_pixels=4194304,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        # Stored as plain scalars. Stray trailing commas previously turned
        # these two into 1-tuples, e.g. ``(2048,)`` and ``(0.5,)``.
        # NOTE(review): confirm no caller indexes these as ``[0]``.
        self.llm_hidden_size = llm_hidden_size
        self.downsample_ratio = downsample_ratio
        self.rope_theta_vision = rope_theta_vision
        self.max_position_embeddings_vision = max_position_embeddings_vision
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        """Load the vision config, unwrapping it from a composite config if needed.

        When the loaded dict has a ``vision_config`` entry, that nested dict
        is used instead of the top-level one. A warning is logged if the
        resulting ``model_type`` differs from this class's ``model_type``.
        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if 'vision_config' in config_dict:
            config_dict = config_dict['vision_config']

        if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
            )

        return cls.from_dict(config_dict, **kwargs)
\ No newline at end of file
"""
Conversation prompt templates.
We kindly request that you import fastchat instead of copying this file if you wish to use it.
If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
Modified from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List, Tuple, Union
class SeparatorStyle(IntEnum):
    """Identifiers for the prompt-separator layouts a conversation template can use.

    Values are spelled out explicitly (1..18, in declaration order) and match
    what consecutive ``auto()`` calls produce for an ``IntEnum``.
    """

    ADD_COLON_SINGLE = 1
    ADD_COLON_TWO = 2
    ADD_COLON_SPACE_SINGLE = 3
    NO_COLON_SINGLE = 4
    NO_COLON_TWO = 5
    ADD_NEW_LINE_SINGLE = 6
    LLAMA2 = 7
    CHATGLM = 8
    CHATML = 9
    CHATINTERN = 10
    DOLLY = 11
    RWKV = 12
    PHOENIX = 13
    ROBIN = 14
    FALCON_CHAT = 15
    CHATGLM3 = 16
    INTERNVL_ZH = 17
    MPT = 18
@dataclasses.dataclass
class Conversation:
    """Prompt template together with the running message history of one chat.

    The template part (``system_template``, ``roles``, ``sep_style``, ``sep``,
    ``sep2``) decides how :meth:`get_prompt` serialises ``messages`` into the
    single string that is fed to the model.
    """

    # The name of this template
    name: str
    # Format string for the system prompt; formatted with ``system_message``
    system_template: str = '{system_message}'
    # The system message; an empty string or None means "no system prompt"
    system_message: str = ''
    # The names of the two roles (first speaker, second speaker)
    roles: Tuple[str, str] = ('USER', 'ASSISTANT')
    # All messages. Each item is [role, message]. A fresh list per instance so
    # that ``append_message`` works on a directly constructed Conversation
    # (the previous immutable ``()`` default made it raise AttributeError).
    messages: List[List[str]] = dataclasses.field(default_factory=list)
    # The number of few shot examples (leading messages skipped on export)
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    # Second separator, only read by the "two separator" styles
    sep2: Union[str, None] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str], None] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Union[List[int], None] = None

    def get_prompt(self) -> str:
        """Render the system prompt and all messages as one prompt string.

        A falsy message (``None`` or ``''``) is rendered as a bare role header
        so that the model continues from that point.

        Returns:
            The prompt laid out according to ``self.sep_style``.

        Raises:
            ValueError: if ``self.sep_style`` is not a handled style.
        """
        if self.system_message is not None and self.system_message != '':
            system_prompt = self.system_template.format(system_message=self.system_message)
        else:
            system_prompt = ''

        # ADD_COLON_SINGLE and FALCON_CHAT render identically.
        if self.sep_style in (SeparatorStyle.ADD_COLON_SINGLE, SeparatorStyle.FALCON_CHAT):
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = '' if system_prompt == '' else system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ': '  # must be end with a space
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    # Normalise line endings and collapse blank lines, because
                    # a blank line is this style's message terminator.
                    ret += role + ': ' + message.replace('\r\n', '\n').replace('\n\n', '\n')
                    ret += '\n\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            ret = system_prompt if system_prompt != '' else '[INST] '
            for i, (_role, message) in enumerate(self.messages):
                # The tag comes from the template roles by position, not from
                # the role stored with the message.
                tag = self.roles[i % 2]
                if message:
                    if i == 0:
                        ret += message + ' '
                    else:
                        ret += tag + ' ' + message + seps[i % 2]
                else:
                    ret += tag
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            round_add_n = 1 if self.name == 'chatglm2' else 0
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += f'[Round {i//2 + round_add_n}]{self.sep}'
                if message:
                    ret += f'{role}:{message}{self.sep}'
                else:
                    ret += f'{role}:'
            return ret
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep + '\n'
                else:
                    ret += role + '\n'
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM3:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + ' ' + message
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':' + message + seps[i % 2] + '\n'
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':\n' + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += '\n\n'
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + '<s>' + message + '</s>'
                else:
                    ret += role + ': ' + '<s>'
            return ret
        elif self.sep_style == SeparatorStyle.ROBIN:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ':\n' + message + self.sep
                else:
                    ret += role + ':\n'
            return ret
        elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
            seps = [self.sep, self.sep2]
            # NOTE: this style emits the raw system_message, bypassing
            # system_template; kept as is to leave existing prompts unchanged.
            ret = '' if system_prompt == '' else self.system_message + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    # A 3-tuple message carries two extra items after the
                    # text; only the text is rendered.
                    if isinstance(message, tuple):
                        message, _, _ = message
                    ret += role + message + self.sep
                elif message is not None:
                    # Empty (but not None) message: close the turn.
                    ret += role + self.sep
                else:
                    # None: leave the turn open for generation.
                    ret += role
            return ret
        else:
            raise ValueError(f'Invalid style: {self.sep_style}')

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new ``[role, message]`` entry to the history."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns:
            A list of ``[first_speaker_message, second_speaker_message]``
            pairs built from the messages after ``offset``; the second item
            is None while the reply is missing.
        """
        ret = []
        for i, (_role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        Messages after ``offset`` alternate between 'user' and 'assistant' by
        position; an assistant entry whose text is None is omitted.
        """
        ret = [{'role': 'system', 'content': self.system_message}]
        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            elif msg is not None:
                ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self):
        """Return a new Conversation whose message list is independent.

        Each ``[role, message]`` pair is rebuilt, so appending to or updating
        the copy does not touch this instance's history.
        """
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return the template name and conversation state as a plain dict."""
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }
# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}
def register_conv_template(template: Conversation, override: bool = False):
    """Add ``template`` to the global registry under ``template.name``.

    Unless ``override`` is true, registering a name that already exists
    trips an assertion instead of replacing the stored template.
    """
    assert override or template.name not in conv_templates, f'{template.name} has been registered.'
    conv_templates[template.name] = template
def get_conv_template(name: str) -> Conversation:
    """Look up a registered template and hand back a copy of it.

    Returning ``copy()`` keeps the registered template untouched when the
    caller appends messages. An unknown ``name`` raises ``KeyError``.
    """
    registered = conv_templates[name]
    return registered.copy()
# Both Hermes-2 and neo1_0-chat are chatml-format conversation templates. The difference
# is that during training, the preprocessing function for the Hermes-2 template doesn't add
# <s> at the beginning of the tokenized sequence, while the neo1_0-chat template does.
# Therefore, they are completely equivalent during inference.
# Built-in templates. All five use SeparatorStyle.MPT with ChatML-like role
# headers; they differ in system prompt, role tokens, separator and stop string.

# 'Hermes-2': '<|im_end|>' separator without trailing newline, explicit stop string.
register_conv_template(
    Conversation(
        name='Hermes-2',
        system_template='<|im_start|>system\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # Unused newer prompt (translated): system_message='I am Shusheng Wanxiang, English name InternVL, a multimodal large language model jointly developed by Shanghai AI Laboratory, Tsinghua University and several partner institutions.',
        # English gloss of the active prompt below: "You are the Shusheng multimodal large model, English name InternVL,
        # developed by Shanghai AI Laboratory together with SenseTime; a helpful and harmless AI assistant."
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>',
        stop_str='<|endoftext|>',
    )
)
# 'internlm2-chat': same layout as 'Hermes-2' but without a stop string.
register_conv_template(
    Conversation(
        name='internlm2-chat',
        system_template='<|im_start|>system\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # Unused newer prompt (translated): system_message='I am Shusheng Wanxiang, English name InternVL, a multimodal large language model jointly developed by Shanghai AI Laboratory, Tsinghua University and several partner institutions.',
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>',
    )
)
# 'phi3-chat': '<|system|>' / '<|user|>' / '<|assistant|>' headers with an '<|end|>' separator.
register_conv_template(
    Conversation(
        name='phi3-chat',
        system_template='<|system|>\n{system_message}',
        # note: The new system prompt was not used here to avoid changes in benchmark performance.
        # Unused newer prompt (translated): system_message='I am Shusheng Wanxiang, English name InternVL, a multimodal large language model jointly developed by Shanghai AI Laboratory, Tsinghua University and several partner institutions.',
        system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
        roles=('<|user|>\n', '<|assistant|>\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|end|>',
    )
)
# 'internvl2_5': uses the newer system prompt and a separator that ends with a newline.
register_conv_template(
    Conversation(
        name='internvl2_5',
        system_template='<|im_start|>system\n{system_message}',
        # English gloss: "You are Shusheng Wanxiang, English name InternVL, a multimodal large language model jointly
        # developed by Shanghai AI Laboratory, Tsinghua University and several partner institutions."
        system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>\n',
    )
)
# 'neo1_0': same layout as 'internvl2_5' but with an empty system message,
# so get_prompt emits no system block for it.
register_conv_template(
    Conversation(
        name='neo1_0',
        system_template='<|im_start|>system\n{system_message}',
        system_message='',
        roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
        sep_style=SeparatorStyle.MPT,
        sep='<|im_end|>\n',
    )
)
import numpy as np
import torch
import torch.nn as nn
import math
from functools import lru_cache
from torch.utils.checkpoint import checkpoint
def modulate(x, shift, scale=None):
    """Apply adaLN-style modulation ``x * (1 + scale) + shift``.

    Either term is skipped when its argument is ``None``. Previously the
    advertised default ``scale=None`` crashed on ``1 + None``; it now means
    "no scaling".

    :param x: value to modulate.
    :param shift: additive term, or ``None`` to skip the shift.
    :param scale: multiplicative term applied as ``1 + scale``, or ``None``.
    :return: the modulated value (``x`` itself when both terms are ``None``).
    """
    if scale is not None:
        x = x * (1 + scale)
    if shift is not None:
        x = x + shift
    return x
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """
        :param dim: size of the last dimension; the gain has one entry per feature.
        :param eps: constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal RMS of its last dimension, then by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = x * inv_rms
        return normed * self.weight
class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        """
        :param hidden_size: width of the MLP output.
        :param frequency_embedding_size: width of the sinusoidal features fed to the MLP.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0):
        """
        Build sinusoidal timestep features.

        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor laid out as [cos | sin], zero-padded by one column when dim is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder, built on CPU and then moved next to ``t``.
        scaled_steps = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(scaled_steps / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        pieces = [torch.cos(args), torch.sin(args)]
        if dim % 2:
            # Odd width: append a single zero column to reach ``dim``.
            pieces.append(torch.zeros_like(args[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Embed timesteps ``t`` and project them through the MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the MLP's parameter dtype before the first matmul.
        mlp_dtype = self.mlp[0].weight.dtype
        return self.mlp(t_freq.to(mlp_dtype))
class ResBlock(nn.Module):
    """Gated residual MLP block whose LayerNorm output is modulated by a condition."""

    def __init__(self, channels, mlp_ratio=1.0):
        """
        :param channels: feature width of the input, output and condition.
        :param mlp_ratio: hidden width of the MLP as a multiple of ``channels``.
        """
        super().__init__()
        self.channels = channels
        self.intermediate_size = int(channels * mlp_ratio)
        self.in_ln = nn.LayerNorm(self.channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.channels, self.intermediate_size),
            nn.SiLU(),
            nn.Linear(self.intermediate_size, self.channels),
        )
        # Produces shift, scale and gate (three ``channels``-wide chunks) from the condition.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True),
        )

    def forward(self, x, y):
        """Return ``x`` plus the gated MLP of the modulated, normalised ``x``.

        :param x: features to refine.
        :param y: conditioning signal fed to the adaLN projection.
        """
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        normed = self.in_ln(x)
        hidden = self.mlp(modulate(normed, shift, scale))
        return x + gate * hidden
# class FinalLayer(nn.Module):
# def __init__(self, model_channels, out_channels):
# super().__init__()
# self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
# self.linear = nn.Linear(model_channels, out_channels, bias=True)
# self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True))
# def forward(self, x, c):
# shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
# x = modulate(self.norm_final(x), shift, scale)
# x = self.linear(x)
# return x
# class SimpleMLPAdaLN(nn.Module):
# def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0):
# super().__init__()
# self.input_dim = input_dim
# self.out_dim = out_dim
# self.dim = dim
# self.layers = layers
# self.mlp_ratio = mlp_ratio
# self.time_embed = TimestepEmbedder(dim)
# self.input_proj = nn.Linear(input_dim, dim)
# res_blocks = []
# for _ in range(layers):
# res_blocks.append(ResBlock(dim, mlp_ratio))
# self.res_blocks = nn.ModuleList(res_blocks)
# self.final_layer = FinalLayer(dim, out_dim)
# self.grad_checkpointing = False
# self.initialize_weights()
# def initialize_weights(self):
# def _basic_init(module):
# if isinstance(module, nn.Linear):
# torch.nn.init.xavier_uniform_(module.weight)
# if module.bias is not None:
# nn.init.constant_(module.bias, 0)
# self.apply(_basic_init)
# # Initialize timestep embedding MLP
# nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
# nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers
# for block in self.res_blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# # Zero-out output layers
# nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
# nn.init.constant_(self.final_layer.linear.weight, 0)
# nn.init.constant_(self.final_layer.linear.bias, 0)
# def forward(self, x, t):
# """
# x.shape = (bsz, input_dim)
# t.shape = (bsz,)
# """
# x = self.input_proj(x)
# t = self.time_embed(t)
# y = t
# for block in self.res_blocks:
# if self.grad_checkpointing and self.training:
# x = checkpoint(block, x, y, use_reentrant=True)
# else:
# x = block(x, y)
# return self.final_layer(x, y)
class FlowMatchingHead(nn.Module):
    """Thin wrapper that owns a ``SimpleMLPAdaLN`` and exposes its dtype/device.

    NOTE(review): the keyword arguments forwarded below (``input_dim``,
    ``out_dim``, ``dim``, ``layers``, ``mlp_ratio``) match the commented-out
    ``SimpleMLPAdaLN`` earlier in this file, not the active definition, whose
    constructor takes ``in_channels, model_channels, out_channels, z_channels,
    num_res_blocks, patch_size``. As written, construction is expected to
    raise ``TypeError`` -- confirm which network this head should wrap.
    """

    def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0):
        super().__init__()
        self.net = SimpleMLPAdaLN(
            input_dim=input_dim,
            out_dim=out_dim,
            dim=dim,
            layers=layers,
            mlp_ratio=mlp_ratio,
        )

    @property
    def dtype(self):
        """dtype of the wrapped network's input projection weight."""
        return self.net.input_proj.weight.dtype

    @property
    def device(self):
        """Device of the wrapped network's input projection weight."""
        return self.net.input_proj.weight.device

    def forward(self, x, t):
        """Delegate to the wrapped network with ``x`` and ``t`` passed positionally."""
        return self.net(x, t)
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
    """Build complex 2-D rotary phase factors for a ``height`` x ``width`` grid.

    Grid coordinates are spread evenly over ``[0, scale]`` on each axis. For
    each of the ``dim // 4`` frequencies the output holds the x phasor
    followed by the y phasor.

    :return: complex tensor of shape ``(height * width, 2 * (dim // 4))``,
        rows in row-major grid order.
    """
    rows = torch.linspace(0, scale, height)
    cols = torch.linspace(0, scale, width)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    flat_y = grid_y.reshape(-1)
    flat_x = grid_x.reshape(-1)

    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    freqs = 1.0 / (theta ** exponents)  # (dim // 4,)

    def _unit_phasors(positions):
        # exp(i * position * freq) for every (position, frequency) pair.
        angles = torch.outer(positions, freqs).float()
        return torch.polar(torch.ones_like(angles), angles)

    x_cis = _unit_phasors(flat_x)
    y_cis = _unit_phasors(flat_y)
    # Interleave as (..., frequency, [x, y]) before flattening the last two axes.
    paired = torch.stack([x_cis, y_cis], dim=-1)
    return paired.reshape(height * width, -1)
class NerfEmbedder(nn.Module):
    """Concatenates fixed 2-D positional features to each patch element and projects them.

    The positional table depends only on ``(patch_size, device, dtype)`` and
    is memoised per instance. The previous ``functools.lru_cache`` on the
    method stored ``self`` in a class-level cache, which kept every instance
    (and its cached tensors) alive for the life of the process.
    """

    def __init__(self, in_channels, hidden_size_input, max_freqs):
        """
        :param in_channels: channels of each input element.
        :param hidden_size_input: output width of the linear projection.
        :param max_freqs: the positional table contributes ``max_freqs ** 2`` extra input features.
        """
        super().__init__()
        self.max_freqs = max_freqs
        self.hidden_size_input = hidden_size_input
        self.embedder = nn.Sequential(
            nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True),
        )
        # (patch_size, device, dtype) -> positional table of shape (1, patch_size**2, F).
        self._pos_cache = {}

    def fetch_pos(self, patch_size, device, dtype):
        """Return the cached positional table for this key, building it on first use.

        :return: tensor of shape ``(1, patch_size ** 2, F)`` on ``device`` with
            ``dtype``; F is set by ``precompute_freqs_cis_2d``.
        """
        # setdefault keeps this working for instances created without __init__.
        cache = self.__dict__.setdefault('_pos_cache', {})
        key = (patch_size, device, dtype)
        pos = cache.get(key)
        if pos is None:
            # Only the real part of the complex phase table is used.
            pos = precompute_freqs_cis_2d(self.max_freqs ** 2 * 2, patch_size, patch_size).real
            pos = pos[None, :, :].to(device=device, dtype=dtype)
            cache[key] = pos
        return pos

    def forward(self, inputs):
        """Append positional features along the last axis and apply the projection.

        :param inputs: tensor of shape ``(B, P2, C)``; ``P2`` is treated as a
            square patch of side ``int(P2 ** 0.5)``.
        :return: tensor of shape ``(B, P2, hidden_size_input)``.
        """
        B, P2, C = inputs.shape
        patch_size = int(P2 ** 0.5)
        dct = self.fetch_pos(patch_size, inputs.device, inputs.dtype)
        dct = dct.repeat(B, 1, 1)
        inputs = torch.cat([inputs, dct], dim=-1)
        return self.embedder(inputs)
class SimpleMLPAdaLN(nn.Module):
    """
    Conditioned residual MLP: input projection, a stack of ``ResBlock`` and a final layer.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks.
    :param patch_size: the condition is projected to ``patch_size ** 2`` vectors of
        ``model_channels`` features each.
    :param grad_checkpointing: when true, residual blocks are run under activation
        checkpointing during training. The flag used to be stored but ignored.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        patch_size,
        grad_checkpointing=False
    ):
        super().__init__()
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing
        self.patch_size = patch_size
        self.cond_embed = nn.Linear(z_channels, patch_size**2*model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)]
        )
        self.final_layer = FinalLayer(model_channels, out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-initialise all linear layers, then zero the adaLN and output projections."""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        # Zero-out adaLN modulation layers
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        # Zero-out output layers
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, c):
        """
        Apply the model to an input batch.

        :param x: inputs whose last dimension is ``in_channels``; presumably shaped
            ``(N, patch_size ** 2, in_channels)`` so they line up with the reshaped
            condition -- TODO confirm against the caller.
        :param c: conditioning whose last dimension is ``z_channels``.
        :return: Tensor whose last dimension is ``out_channels``.
        """
        x = self.input_proj(x)
        c = self.cond_embed(c)
        # One conditioning vector per patch element.
        y = c.reshape(-1, self.patch_size**2, self.model_channels)
        # Checkpointing only pays off when a backward pass will actually run.
        use_checkpoint = self.grad_checkpointing and self.training and torch.is_grad_enabled()
        for block in self.res_blocks:
            if use_checkpoint:
                x = checkpoint(block, x, y, use_reentrant=False)
            else:
                x = block(x, y)
        return self.final_layer(x)
class FinalLayer(nn.Module):
    """
    Output head: parameter-free LayerNorm followed by a linear projection.
    """

    def __init__(self, model_channels, out_channels):
        """
        :param model_channels: width of the incoming features.
        :param out_channels: width of the projected output.
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)

    def forward(self, x):
        """Normalise ``x`` over its last dimension and project it to ``out_channels``."""
        return self.linear(self.norm_final(x))
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0):
    """
    Build a fixed 2-D sine/cosine position table for a square grid.

    grid_size: int of the grid height and width
    pe_interpolation: grid coordinates are divided by this factor
    return:
    pos_embed: [grid_size*grid_size, embed_dim], or
        [extra_tokens+grid_size*grid_size, embed_dim] with leading zero rows when
        cls_token is set and extra_tokens > 0
    """
    # Height and width share the same coordinate axis on a square grid.
    axis = np.arange(grid_size, dtype=np.float32) / pe_interpolation
    mesh = np.meshgrid(axis, axis)  # here w goes first
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        zero_rows = np.zeros([extra_tokens, embed_dim])
        pos_embed = np.concatenate([zero_rows, pos_embed], axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode each of the two coordinate planes of ``grid`` with half of ``embed_dim``.

    grid: array whose first axis indexes the two coordinate planes
    return: (H*W, embed_dim) -- the encoding of grid[0] followed by that of grid[1]
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane])  # (H*W, D/2)
        for plane in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine encoding of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; flattened to size (M,)
    out: (M, D), the sine half followed by the cosine half
    """
    assert embed_dim % 2 == 0
    exponents = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponents  # (D/2,)
    flat_pos = pos.reshape(-1)  # (M,)
    angles = flat_pos[:, None] * omega[None, :]  # (M, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model_path, pe_key: str = "gen_pos_embed", new_len: int = 4096):
    """Resize a square position-embedding table stored in a checkpoint, in place on disk.

    The tensor under ``pe_key`` is read as ``(batch, ori_len, embed_dim)``,
    viewed as a square grid, bicubically resampled to
    ``int(new_len ** 0.5)`` per side and written back to ``model_path``.
    When the grid side already matches, the checkpoint file is left untouched.

    :param model_path: path of a checkpoint readable by ``torch.load``; overwritten on resize.
    :param pe_key: state-dict key of the position embedding.
    :param new_len: target number of positions (expected to be a perfect square).
    """
    # Imported locally: this module defines no ``logger``, so the original
    # ``logger.info`` call raised NameError on the resize path.
    import logging

    # NOTE(security): torch.load unpickles the file; only use on trusted checkpoints.
    state_dict = torch.load(model_path, map_location="cpu")
    pos_embed_1d = state_dict[pe_key]
    _, ori_len, embed_dim = pos_embed_1d.shape
    ori_size = int(ori_len**0.5)
    new_size = int(new_len**0.5)
    if ori_size == new_size:
        return

    logging.getLogger(__name__).info(
        "Position interpolate from %dx%d to %dx%d", ori_size, ori_size, new_size, new_size
    )
    # (B, L, D) -> (B, D, side, side): interpolate() expects channels first.
    pos_embed_2d = pos_embed_1d.reshape(-1, ori_size, ori_size, embed_dim).permute(0, 3, 1, 2)
    pos_embed_2d = torch.nn.functional.interpolate(
        pos_embed_2d, size=(new_size, new_size), mode="bicubic", align_corners=False
    )
    # Back to (B, new_side * new_side, D).
    state_dict[pe_key] = pos_embed_2d.permute(0, 2, 3, 1).flatten(1, 2)
    torch.save(state_dict, model_path)
class PositionEmbedding(nn.Module):
    """Frozen 2-D sine/cosine position table looked up by flat position id."""

    def __init__(self, max_num_patch_per_side, hidden_size):
        """
        :param max_num_patch_per_side: side length of the square grid the table covers.
        :param hidden_size: embedding width of each position.
        """
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size
        table = torch.zeros(max_num_patch_per_side ** 2, hidden_size)
        # Stored as a parameter so it travels with the state dict, but never trained.
        self.pos_embed = nn.Parameter(table, requires_grad=False)
        self._init_weights()

    def _init_weights(self):
        """Fill the frozen table with the sin-cos embedding."""
        sincos = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
        self.pos_embed.data.copy_(torch.from_numpy(sincos).float())

    def forward(self, position_ids):
        """Return the rows of the table selected by ``position_ids``."""
        return self.pos_embed[position_ids]
class ResidualConvBlock(nn.Module):
    """Conv -> SiLU -> Conv residual branch that starts out as the identity."""

    def __init__(self, channels: int):
        super().__init__()
        conv_in = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        conv_out = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # Zeroing the last conv makes the residual branch contribute nothing at init.
        for param in (conv_out.weight, conv_out.bias):
            nn.init.zeros_(param)
        self.block = nn.Sequential(conv_in, nn.SiLU(), conv_out)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual
class PostConvSmoother(nn.Module):
    """Residual convolutional refiner whose correction is zero at initialisation."""

    def __init__(self, in_channels=3, hidden_channels=64, num_blocks=3):
        """
        :param in_channels: channels of the input and of the output.
        :param hidden_channels: width of the internal feature maps.
        :param num_blocks: number of ``ResidualConvBlock`` stages.
        """
        super().__init__()
        self.in_proj = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
        stages = [ResidualConvBlock(hidden_channels) for _ in range(num_blocks)]
        self.blocks = nn.Sequential(*stages)
        self.out_proj = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)
        # Zero output projection: the module is the identity until trained.
        for param in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(param)

    def forward(self, x):
        """Return ``x`` plus the learned correction."""
        correction = self.out_proj(self.blocks(self.in_proj(x)))
        return x + correction
class ProgressiveConvDecoder(nn.Module):
    """Five nearest-neighbour 2x upsampling stages (32x overall) followed by an output conv."""

    def __init__(self, hidden_dim=4096, out_channels=3):
        """
        :param hidden_dim: channels of the incoming feature map.
        :param out_channels: channels of the decoded output.
        """
        super().__init__()
        # (in_channels, out_channels, GroupNorm groups); None means the stage has no norm.
        stage_specs = (
            (hidden_dim, 512, 32),
            (512, 256, 32),
            (256, 64, 32),
            (64, 32, 16),
            (32, 16, None),
        )
        stages = []
        for c_in, c_out, groups in stage_specs:
            layers = [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
            ]
            if groups is not None:
                layers.append(nn.GroupNorm(groups, c_out))
            layers.append(nn.SiLU())
            stages.append(nn.Sequential(*layers))
        self.up_blocks = nn.ModuleList(stages)
        self.out_conv = nn.Conv2d(16, out_channels, kernel_size=3, padding=1)

    def forward(self, x_2d):
        """Decode a ``(B, hidden_dim, H, W)`` map into ``(B, out_channels, 32*H, 32*W)``."""
        features = x_2d
        for stage in self.up_blocks:
            features = stage(features)
        return self.out_conv(features)
class PatchDecoder_postps(nn.Module):
    """Two-stage decoder that convolves first and pixel-shuffles afterwards (32x overall)."""

    def __init__(self):
        super().__init__()
        # stage 1: 4096 channels -> PixelShuffle(4) -> 256 channels at 4x resolution
        self.conv1 = nn.Conv2d(4096, 4096, kernel_size=3, padding=1)
        self.ps1 = nn.PixelShuffle(4)
        self.act1 = nn.GELU()
        # stage 2: 192 channels -> PixelShuffle(8) -> 3 channels at a further 8x resolution
        self.conv2 = nn.Conv2d(256, 192, kernel_size=3, padding=1)
        self.ps2 = nn.PixelShuffle(8)

    def forward(self, x):
        """Map ``[B, 4096, h, w]`` to ``[B, 3, 32*h, 32*w]``."""
        pipeline = (self.conv1, self.act1, self.ps1, self.conv2, self.ps2)
        for stage in pipeline:
            x = stage(x)
        return x
class PatchDecoder_preps(nn.Module):
    """Three-stage decoder that pixel-shuffles before each convolution (32x overall)."""

    def __init__(self):
        super().__init__()
        # stage 1: PixelShuffle(2) turns 4096 channels into 1024 at 2x resolution
        self.ps1 = nn.PixelShuffle(2)
        self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
        self.act1 = nn.GELU()
        # stage 2: PixelShuffle(2) turns 1024 channels into 256 at a further 2x
        self.ps2 = nn.PixelShuffle(2)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.act2 = nn.GELU()
        # stage 3: PixelShuffle(8) turns 256 channels into 4 at a further 8x; conv maps 4 -> 3
        self.ps3 = nn.PixelShuffle(8)
        self.conv3 = nn.Conv2d(4, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``[B, 4096, h, w]`` to ``[B, 3, 32*h, 32*w]``."""
        pipeline = (
            self.ps1, self.conv1, self.act1,  # -> [B, 1024, 2h, 2w]
            self.ps2, self.conv2, self.act2,  # -> [B, 256, 4h, 4w]
            self.ps3, self.conv3,             # -> [B, 3, 32h, 32w]
        )
        for stage in pipeline:
            x = stage(x)
        return x
class PatchDecoder_preps1(nn.Module):
    """Pixel-shuffle-first decoder with two convolutions and a final 8x shuffle (32x overall)."""

    def __init__(self):
        super().__init__()
        # stage 1: PixelShuffle(2) turns 4096 channels into 1024 at 2x resolution
        self.ps1 = nn.PixelShuffle(2)
        self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
        self.act1 = nn.GELU()
        # stage 2: PixelShuffle(2) turns 1024 channels into 256; conv expands to 192
        self.ps2 = nn.PixelShuffle(2)
        self.conv2 = nn.Conv2d(256, 192, kernel_size=3, padding=1)
        # stage 3: PixelShuffle(8) turns 192 channels into 3 at a further 8x
        self.ps3 = nn.PixelShuffle(8)

    def forward(self, x):
        """Map ``[B, 4096, h, w]`` to ``[B, 3, 32*h, 32*w]``."""
        pipeline = (
            self.ps1, self.conv1, self.act1,  # -> [B, 1024, 2h, 2w]
            self.ps2, self.conv2,             # -> [B, 192, 4h, 4w]
            self.ps3,                         # -> [B, 3, 32h, 32w]
        )
        for stage in pipeline:
            x = stage(x)
        return x
class ConvDecoder(nn.Module):
    """Configurable pixel-shuffle decoder: shuffle 2x, conv, shuffle 2x, conv, shuffle 8x."""

    def __init__(self, input_dim=4096, hidden_dim=1024):
        """
        :param input_dim: channels of the incoming map; the first shuffle leaves ``input_dim // 4``.
        :param hidden_dim: channels after the first conv; the second shuffle leaves ``hidden_dim // 4``.
        """
        super().__init__()
        # stage 1: 2x upscale, then widen/narrow to hidden_dim
        self.ps1 = nn.PixelShuffle(2)
        self.conv1 = nn.Conv2d(input_dim // 4, hidden_dim, kernel_size=3, padding=1)
        self.act1 = nn.GELU()
        # stage 2: 2x upscale, then 192 channels (= 3 * 8 * 8 for the final shuffle)
        self.ps2 = nn.PixelShuffle(2)
        self.conv2 = nn.Conv2d(hidden_dim // 4, 192, kernel_size=3, padding=1)
        # stage 3: 8x upscale down to 3 channels
        self.ps3 = nn.PixelShuffle(8)

    def forward(self, x):
        """Map ``[B, input_dim, h, w]`` to ``[B, 3, 32*h, 32*w]``."""
        for stage in (self.ps1, self.conv1, self.act1, self.ps2, self.conv2, self.ps3):
            x = stage(x)
        return x
from typing import List, Optional, Tuple, Union
import math
import os
import torch.utils.checkpoint
from torch import nn
import transformers
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_neo_chat import NEOChatConfig, NEOMoELLMConfig
from .conversation import get_conv_template
from .modeling_neo_vit import NEOVisionModel
from .modeling_qwen3 import Qwen3ForCausalLM, create_block_causal_mask
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
from .modeling_fm_modules import PositionEmbedding, TimestepEmbedder, FlowMatchingHead, RMSNorm, NerfEmbedder, SimpleMLPAdaLN, ConvDecoder
from .utils import load_image_native, SYSTEM_MESSAGE_FOR_GEN
logger = logging.get_logger(__name__)
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings with a named ``operator`` function.

    :param v1: left-hand version string.
    :param v2: right-hand version string.
    :param op: name of a function in the ``operator`` module, e.g. 'eq', 'ge', 'lt'.
    :return: result of ``operator.<op>(parse(v1), parse(v2))``.
    """
    import operator
    from packaging import version

    compare = getattr(operator, op)
    left, right = version.parse(v1), version.parse(v2)
    return compare(left, right)
def prepare_flash_kv_cache(
    past_key_values,
    current_len: int,
    batch_size: int,
):
    """
    Attach preallocated ``[prefix + current]`` KV buffers to every cache layer.

    For each layer the cached keys/values are transposed on dims 1 and 2
    (assumed to go from [B, H, S, D] to the flash-attn friendly [B, S, H, D])
    and copied into the head of a new buffer with room for ``current_len``
    more positions. Results are stored on the layer as ``flash_prefix_len``,
    ``flash_total_len``, ``flash_k_cache`` and ``flash_v_cache``.
    Layers without cached keys or values get a prefix of 0 and no buffers.
    Intended to run once before the denoising loop; a ``None`` cache is a no-op.
    """
    if past_key_values is None:
        return

    def _alloc_with_prefix(prefix, total_len):
        # The tail beyond the prefix is left uninitialised for later writes.
        buffer = torch.empty(
            (batch_size, total_len, prefix.shape[2], prefix.shape[3]),
            device=prefix.device,
            dtype=prefix.dtype,
        )
        buffer[:, :prefix.shape[1]].copy_(prefix)
        return buffer

    for layer in past_key_values.layers:
        keys, values = layer.keys, layer.values
        if keys is None or values is None:
            layer.flash_prefix_len = 0
            layer.flash_total_len = current_len
            layer.flash_k_cache = None
            layer.flash_v_cache = None
            continue

        keys_flash = keys.transpose(1, 2).contiguous()      # [B, S, H, D]
        values_flash = values.transpose(1, 2).contiguous()  # [B, S, H, D]
        prefix_len = keys_flash.shape[1]
        total_len = prefix_len + current_len

        layer.flash_prefix_len = prefix_len
        layer.flash_total_len = total_len
        layer.flash_k_cache = _alloc_with_prefix(keys_flash, total_len)
        layer.flash_v_cache = _alloc_with_prefix(values_flash, total_len)
def clear_flash_kv_cache(past_key_values):
    """Remove the ``flash_*`` scratch attributes from every cache layer.

    Attributes that are not present are skipped, so calling this on a cache
    that was never prepared (or passing ``None``) is harmless.
    """
    if past_key_values is None:
        return
    scratch_attrs = (
        "flash_prefix_len",
        "flash_total_len",
        "flash_k_cache",
        "flash_v_cache",
    )
    for layer in past_key_values.layers:
        for attr in scratch_attrs:
            if hasattr(layer, attr):
                delattr(layer, attr)
def optimized_scale(positive_flat, negative_flat):
    """Per-row projection coefficient of ``positive_flat`` onto ``negative_flat``.

    Computes ``st_star = <v_cond, v_uncond> / (||v_uncond||^2 + 1e-8)`` along
    dim 1 and returns it with that dim kept, in float32.
    """
    # The squared norm and the division must not run in fp16/bf16, so autocast
    # is switched off for this region regardless of the caller's context.
    # The device type comes from the input so CUDA and XPU both work; ``mps``
    # is mapped to ``cpu`` because torch.autocast rejects it.
    input_device = positive_flat.device.type
    autocast_device = "cpu" if input_device == "mps" else input_device
    with torch.autocast(device_type=autocast_device, enabled=False):
        cond = positive_flat.float()
        uncond = negative_flat.float()
        # Row-wise dot product between the two branches.
        dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
        # Squared norm of the unconditional branch; the epsilon guards against division by zero.
        squared_norm = torch.sum(uncond ** 2, dim=1, keepdim=True) + 1e-8
        return dot_product / squared_norm
def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
    """
    Compute (x, y) patch coordinates for a batch of images laid out back to back.

    Args:
        grid_hw: (B, 2) tensor representing (H, W) per image
        device: accepted but not used; the result lives on ``grid_hw.device``

    Returns:
        ``(abs_x, abs_y)``: two 1-D tensors of length ``sum(H * W)`` giving the
        column and row of every patch, images concatenated in batch order and
        patches enumerated row-major within each image.
    """
    device = grid_hw.device
    num_images = grid_hw.shape[0]
    heights = grid_hw[:, 0]
    widths = grid_hw[:, 1]
    patches_per_image = heights * widths
    total_patches = patches_per_image.sum()

    # For every patch, the index of the image it belongs to: (N_total,)
    owner = torch.repeat_interleave(torch.arange(num_images, device=device), patches_per_image)

    # Flat index of each image's first patch (exclusive running sum).
    first_patch = torch.cumsum(patches_per_image, dim=0) - patches_per_image
    # Row-major index of every patch inside its own image.
    local_index = torch.arange(total_patches, device=device) - first_patch[owner]

    # Split the local index into column and row using the owning image's width.
    width_of_owner = widths[owner]
    abs_x = local_index % width_of_owner
    abs_y = local_index // width_of_owner
    return abs_x, abs_y
class NEOChatModel(PreTrainedModel):
config_class = NEOChatConfig
main_input_name = 'pixel_values'
base_model_prefix = 'language_model'
_supports_flash_attn_2 = True
supports_gradient_checkpointing = True
_no_split_modules = [
"NEOVisionModel",
"Qwen3DecoderLayer",
"Qwen3MoeDecoderLayer",
]
# support transformers 4.51.+
_tp_plan = ''
def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
    """
    Build the understanding encoder, the language backbone and the
    flow-matching modules, then copy the sampling hyper-parameters from
    ``config`` onto the instance.

    Args:
        config: model configuration; ``config.vision_config`` and
            ``config.llm_config`` drive the construction of the sub-modules.
        vision_model: optional ready-made vision encoder; a new
            ``NEOVisionModel`` is created when omitted.
        language_model: optional ready-made language model; a new Qwen3 or
            Qwen3-MoE causal LM is created when omitted.
        use_flash_attn: accepted but not read by this constructor.
    """
    super().__init__(config)
    assert version_cmp(transformers.__version__, '4.37.0', 'ge')

    patch_size = config.vision_config.patch_size
    self.patch_size = patch_size
    self.template = config.template
    self.downsample_ratio = config.downsample_ratio
    # The language model is always configured with the eager attention path.
    config.llm_config._attn_implementation = 'eager'

    # Encoder used for image understanding.
    self.vision_model = (
        NEOVisionModel(config.vision_config) if vision_model is None else vision_model
    )
    # A second encoder, always freshly built, used on the generation path.
    gen_vision_model = NEOVisionModel(config.vision_config)

    if language_model is None:
        # Pick the right backbone class based on the LLM config: dense
        # Qwen3 (DANCE family) or Qwen3-MoE (A3B family). Both share the
        # same NEO-Unify two-branch attention/norm layout, so the rest of
        # this class works against either.
        is_moe = isinstance(config.llm_config, NEOMoELLMConfig)
        backbone_cls = Qwen3MoeForCausalLM if is_moe else Qwen3ForCausalLM
        language_model = backbone_cls(config.llm_config)
    self.language_model = language_model

    merge_size = int(1 / self.downsample_ratio)
    # One predicted value per channel of every pixel in a merged patch.
    output_dim = 3 * (patch_size * merge_size) ** 2
    llm_hidden_size = self.config.llm_config.hidden_size
    self.use_deep_fm_head = self.config.fm_head_layers > 2
    self.use_pixel_head = self.config.use_pixel_head

    if not self.use_deep_fm_head:
        fm_head = nn.Sequential(
            nn.Linear(llm_hidden_size, 4096, bias=True),
            nn.GELU(),
            nn.Linear(4096, output_dim, bias=True),
        )
    else:
        fm_head = FlowMatchingHead(
            llm_hidden_size,
            output_dim,
            dim=self.config.fm_head_dim,
            layers=self.config.fm_head_layers,
            mlp_ratio=self.config.fm_head_mlp_ratio,
        )
    timestep_embedder = TimestepEmbedder(llm_hidden_size)

    fm_modules = nn.ModuleDict()
    fm_modules["vision_model_mot_gen"] = gen_vision_model
    fm_modules["timestep_embedder"] = timestep_embedder
    fm_modules["fm_head"] = fm_head
    self.fm_modules = fm_modules
    if self.use_pixel_head:
        # The pixel head replaces whichever head was built above.
        self.fm_modules["fm_head"] = ConvDecoder(llm_hidden_size)

    # Sampling / noise hyper-parameters mirrored from the configuration.
    mirrored_fields = (
        "concat_time_token_num",
        "noise_scale",
        "noise_scale_mode",
        "noise_scale_base_image_seq_len",
        "add_noise_scale_embedding",
        "noise_scale_max_value",
        "time_schedule",
        "time_shift_type",
        "base_shift",
        "max_shift",
        "base_image_seq_len",
        "max_image_seq_len",
    )
    for field_name in mirrored_fields:
        setattr(self, field_name, getattr(config, field_name))

    if self.add_noise_scale_embedding:
        self.fm_modules['noise_scale_embedder'] = TimestepEmbedder(llm_hidden_size)

    # Token ids; the generation entry points overwrite them from the tokenizer.
    self.img_context_token_id = None
    self.img_start_token_id = 151670
    self.last_think_content = ""

    self.conv_template = get_conv_template(self.template)
    self.system_message = self.conv_template.system_message
def forward(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    image_flags: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """
    Placeholder for the standard forward pass, which this model does not
    implement.

    The signature is kept so the class still exposes the usual interface;
    every argument is ignored. The statements that used to follow the
    ``raise`` could never execute and have been removed.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('forward')
def extract_feature(self, pixel_values, gen_model=False, grid_hw=None):
if gen_model:
return self.fm_modules['vision_model_mot_gen'](pixel_values=pixel_values,
output_hidden_states=False,
return_dict=True,
grid_hw=grid_hw).last_hidden_state
else:
return self.vision_model(pixel_values=pixel_values,
output_hidden_states=False,
return_dict=True,
grid_hw=grid_hw).last_hidden_state
def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
               history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
               IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
    """
    Placeholder for batched chat, which this model does not implement.

    The signature is kept so the class still exposes the usual interface;
    every argument is ignored. The statements that used to follow the
    ``raise`` could never execute and have been removed.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('batch_chat')
def patchify(self, images, patch_size, channel_first=False):
    """
    Cut a batch of 3-channel images into flattened, non-overlapping square
    patches, enumerated in row-major order.

    images: (N, 3, H, W)
    returns: (N, L, patch_size**2 * 3) with L = (H // patch_size) * (W // patch_size)

    Inside a patch the values are ordered (channel, row, col) when
    ``channel_first`` is true and (row, col, channel) otherwise.
    """
    batch = images.shape[0]
    rows = images.shape[2] // patch_size
    cols = images.shape[3] // patch_size
    # Axes after the split: (n, c, h, p, w, q).
    tiles = images.reshape(batch, 3, rows, patch_size, cols, patch_size)
    # Move the grid axes (h, w) forward, then order the per-patch axes.
    axis_order = (0, 2, 4, 1, 3, 5) if channel_first else (0, 2, 4, 3, 5, 1)
    tiles = tiles.permute(*axis_order)
    return tiles.reshape(batch, rows * cols, patch_size ** 2 * 3)
def unpatchify(self, x, patch_size, h=None, w=None):
    """
    Reassemble images from flattened patches laid out as (row, col, channel),
    i.e. the inverse of ``patchify`` with ``channel_first=False``.

    x: (N, L, patch_size**2 *3)
    h, w: target image height and width in pixels. When either one is
        omitted the patch grid is assumed to be square (sqrt(L) x sqrt(L)).
    images: (N, 3, H, W)
    """
    if h is None or w is None:
        # No explicit size: fall back to a square patch grid.
        h = w = int(x.shape[1] ** .5)
    else:
        # Convert the pixel size into a patch-grid size.
        h = h // patch_size
        w = w // patch_size
    x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, 3))
    # (n, h, w, p, q, c) -> (n, c, h, p, w, q) so that (h, p) and (w, q) merge.
    x = torch.einsum('nhwpqc->nchpwq', x)
    images = x.reshape(shape=(x.shape[0], 3, h * patch_size, w * patch_size))
    return images
def _euler_step(self, v_pred, z, t, t_next):
z_next = z + (t_next - t) * v_pred
return z_next
def _calculate_dynamic_mu(self, image_seq_len: int) -> float:
denom = self.max_image_seq_len - self.base_image_seq_len
if denom == 0:
return float(self.base_shift)
m = (self.max_shift - self.base_shift) / denom
b = self.base_shift - m * self.base_image_seq_len
return float(image_seq_len) * m + b
def _apply_time_schedule(self, t: torch.Tensor, image_seq_len: int, timestep_shift: float) -> torch.Tensor:
self.time_schedule = "standard"
sigma = 1 - t
if timestep_shift != 1:
self.time_schedule = "standard"
if self.time_schedule == "standard":
shift = timestep_shift
sigma = shift * sigma / (1 + (shift - 1) * sigma)
elif self.time_schedule == "dynamic":
mu = self._calculate_dynamic_mu(image_seq_len)
mu_t = t.new_tensor(mu)
if self.time_shift_type == "exponential":
shift = torch.exp(mu_t)
sigma = shift * sigma / (1 + (shift - 1) * sigma)
elif self.time_shift_type == "linear":
sigma = mu_t / (mu_t + (1 / sigma - 1))
else:
raise ValueError(f"Unsupported time_shift_type: {self.time_shift_type}")
else:
raise ValueError(f"Unsupported time_schedule: {self.time_schedule}")
return 1 - sigma
def _build_t2i_query(self, prompt_text, system_message=None, append_text=None):
    """
    Render a single-turn prompt that ends with an open assistant turn.

    Args:
        prompt_text: content of the first-role message.
        system_message: system message to use; falls back to
            ``self.system_message`` when ``None``.
        append_text: optional text concatenated after the rendered prompt.

    Returns:
        The rendered prompt string.
    """
    conv = get_conv_template(self.template)
    conv.system_message = system_message if system_message is not None else self.system_message
    conv.append_message(conv.roles[0], prompt_text)
    # A ``None`` message leaves the second role's turn open.
    conv.append_message(conv.roles[1], None)
    rendered = conv.get_prompt()
    return rendered if append_text is None else rendered + append_text
def _build_t2i_text_inputs(self, tokenizer, query: str):
    """
    Tokenize a text-only query and build its (t, h, w) position indexes and
    attention mask.

    Returns:
        ``(input_ids, indexes, attention_mask)`` where ``indexes`` has shape
        (3, seq_len): row 0 counts 0..seq_len-1 and rows 1 and 2 are zero.
        ``attention_mask`` is a dict whose ``"full_attention"`` entry comes
        from ``create_block_causal_mask`` applied to row 0.
    """
    input_ids = tokenizer(query, return_tensors="pt")["input_ids"].to(self.device)
    positions = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    no_offset = torch.zeros_like(positions)
    # Text tokens only advance along the first axis; the other two stay at zero.
    indexes = torch.stack([positions, no_offset, no_offset], dim=0)
    attention_mask = {"full_attention": create_block_causal_mask(indexes[0])}
    return input_ids, indexes, attention_mask
def _build_t2i_image_indexes(self, token_h, token_w, text_len, device):
t_image = torch.full((token_h * token_w,), text_len, dtype=torch.long, device=device)
idx = torch.arange(token_h * token_w, device=device, dtype=torch.long)
h_image = idx // token_w
w_image = idx % token_w
return torch.stack([t_image, h_image, w_image], dim=0)
def _t2i_prefix_forward(self, input_ids, indexes, attention_mask):
out = self.language_model.model(
input_ids=input_ids,
indexes=indexes,
attention_mask=attention_mask,
use_cache=True,
)
return out.past_key_values, out.last_hidden_state
def _it2i_prefix_forward(self, input_imbeds, indexes, attention_mask, gen_indicators=None):
out = self.language_model.model(
inputs_embeds=input_imbeds,
indexes=indexes,
attention_mask=attention_mask,
use_cache=True,
image_gen_indicators=gen_indicators.view(1, -1) if gen_indicators is not None else None
)
return out.past_key_values, out.last_hidden_state
def _append_text_tokens_to_cache(self, cache, t_idx, input_ids):
if input_ids.shape[1] == 0:
return t_idx
device = input_ids.device
seq_len = input_ids.shape[1]
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
t_indexes = torch.arange(t_idx + 1, t_idx + 1 + seq_len, dtype=torch.long, device=device)
h_indexes = torch.zeros(seq_len, dtype=torch.long, device=device)
w_indexes = torch.zeros(seq_len, dtype=torch.long, device=device)
indexes = torch.stack([t_indexes, h_indexes, w_indexes], dim=0)
past_len = cache.get_seq_length()
mask = torch.zeros(1, 1, seq_len, past_len + seq_len, device=device)
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
causal_mask = torch.where(causal_mask == 1, 0.0, float('-inf'))
mask[:, :, :, past_len:] = causal_mask
attention_mask_dict = {"full_attention": mask}
self.language_model(
inputs_embeds=inputs_embeds,
indexes=indexes,
attention_mask=attention_mask_dict,
past_key_values=cache,
use_cache=True
)
return t_idx + seq_len
def _generate_think(
    self,
    tokenizer,
    prefix_outputs,
    past_key_values,
    t_idx,
    IMG_START_TOKEN,
    max_think_tokens=1024,
):
    """
    Greedily decode a reasoning segment and then prime the cache for image
    generation.

    Decoding starts from the last-position logits of ``prefix_outputs`` and
    stops at the end-of-turn token (not fed to the model), at ``</think>``
    (fed to the model and kept in the text), or after ``max_think_tokens``
    tokens. Afterwards ``'\\n\\n' + IMG_START_TOKEN`` is appended to the cache.

    Args:
        tokenizer: tokenizer used to look up token ids and decode the text.
        prefix_outputs: output of the prompt forward pass; ``logits`` is read.
        past_key_values: key/value cache matching ``prefix_outputs``.
        t_idx: position index written to ``current_index`` before each step.
        IMG_START_TOKEN: image start marker appended after the reasoning.
        max_think_tokens: upper bound on decoding steps.

    Returns:
        ``(past_key_values, t_idx, think_text)`` with the updated cache, the
        updated position index and the decoded reasoning text.
    """
    conv = get_conv_template(self.template)
    eos_token_id = tokenizer.convert_tokens_to_ids(conv.sep.strip())
    think_end_token_id = tokenizer.convert_tokens_to_ids('</think>')

    think_token_ids = []
    next_token = torch.argmax(prefix_outputs.logits[:, -1, :], dim=-1)
    for _ in range(max_think_tokens):
        token_item = next_token.item()
        if token_item == eos_token_id:
            break

        # Every non-EOS token, including ``</think>``, is recorded and pushed
        # through the model so that the cache contains it.
        think_token_ids.append(token_item)
        self.language_model.model.current_index = t_idx
        outputs = self.language_model(
            input_ids=next_token.unsqueeze(0),
            past_key_values=past_key_values,
            use_cache=True
        )
        past_key_values = outputs.past_key_values
        t_idx += 1

        if token_item == think_end_token_id:
            break
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)

    append_ids = tokenizer(
        '\n\n' + IMG_START_TOKEN,
        return_tensors='pt',
        add_special_tokens=False,
    )['input_ids'].to(self.device)
    t_idx = self._append_text_tokens_to_cache(past_key_values, t_idx, append_ids)

    think_text = tokenizer.decode(think_token_ids, skip_special_tokens=False)
    return past_key_values, t_idx, think_text
def _t2i_predict_v(self, input_embeds, indexes_image, attn_mask, past_key_values, t, z, image_token_num, timestep_embeddings=None, image_size=None):
B, L = z.shape[0], z.shape[1]
outputs = self.language_model.model(
inputs_embeds=input_embeds,
image_gen_indicators=torch.ones((input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=input_embeds.device),
indexes=indexes_image,
attention_mask=attn_mask,
past_key_values=past_key_values,
update_cache=False,
use_cache=True,
)
if self.use_pixel_head:
merge_size = int(1 / self.downsample_ratio)
token_h = image_size[1] // (self.patch_size * merge_size)
token_w = image_size[0] // (self.patch_size * merge_size)
img_reshaped = outputs.last_hidden_state[:, -image_token_num:].view(B, token_h, token_w, -1)
img_2d = torch.einsum("b h w c -> b c h w", img_reshaped)
img_2d = img_2d.contiguous().view(B, -1, token_h, token_w)
smoothed_img_2d = self.fm_modules['fm_head'](img_2d)
smoothed_reshaped = smoothed_img_2d.view(B, 3, token_h, self.patch_size * merge_size, token_w, self.patch_size * merge_size)
smoothed_reshaped = torch.einsum("b c h p w q -> b h w p q c", smoothed_reshaped)
out_1d = smoothed_reshaped.contiguous().view(B, L, self.patch_size * merge_size * self.patch_size * merge_size * 3)
x_pred = out_1d
else:
if self.use_deep_fm_head:
x_pred = self.fm_modules["fm_head"](
outputs.last_hidden_state[:, -image_token_num:].view(B*L, -1), t.repeat(B*L)
).view(B, L, -1)
else:
x_pred = self.fm_modules["fm_head"](
outputs.last_hidden_state[:, -image_token_num:].view(B, L, -1)
).view(B, L, -1)
v_pred = (x_pred - z) / (1 - t).clamp_min(self.config.t_eps)
return v_pred
def _build_it2i_inputs(self, tokenizer, query, pixel_values=None, grid_hw=None):
    """
    Tokenize ``query`` and build embeddings, position indexes and attention
    mask, splicing visual features into the image-context positions.

    Args:
        tokenizer: tokenizer applied to ``query``.
        query: prompt string; when ``pixel_values`` is given it must contain
            image-context tokens.
        pixel_values: optional image input for ``extract_feature``.
        grid_hw: forwarded to ``get_thw_indexes`` and ``extract_feature``.

    Returns:
        ``(input_embeds, indexes, attention_mask)``.

    Raises:
        AssertionError: if ``pixel_values`` is given but no token of
            ``query`` equals ``self.img_context_token_id``.
    """
    input_ids = tokenizer(query, return_tensors="pt")["input_ids"].to(self.device)
    indexes = self.get_thw_indexes(input_ids[0], grid_hw)
    attention_mask = {"full_attention": create_block_causal_mask(indexes[0])}

    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    B, N, C = input_embeds.shape
    if pixel_values is None:
        return input_embeds, indexes, attention_mask

    vit_embeds = self.extract_feature(pixel_values, grid_hw=grid_hw)
    flat_embeds = input_embeds.reshape(B * N, C)
    is_image_slot = input_ids.reshape(B * N) == self.img_context_token_id
    assert is_image_slot.sum() != 0
    # Overwrite the placeholder embeddings with the visual features, in order.
    flat_embeds[is_image_slot] = vit_embeds.reshape(-1, C).to(flat_embeds.device)
    return flat_embeds.reshape(B, N, C), indexes, attention_mask
@torch.no_grad()
def interleave_gen_image_only(
    self,
    tokenizer,
    prompt,
    gt_text,
    images=None,
    gt_images=None,
    cfg_scale=1.0,
    img_cfg_scale=1.0,
    cfg_norm='none',
    max_images=10,
    enable_timestep_shift=True,
    timestep_shift=1.0,
    image_size=(256, 256),
    num_steps=30,
    IMG_START_TOKEN='<img>',
    IMG_END_TOKEN='</img>',
    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
    method='euler',
    cfg_interval=(0, 1),
    t_eps=0.02,
    verbose=False,
    system_message='',
):
    """
    Generate the images of an interleaved answer whose text is already known.

    ``gt_text`` is split on ``'<image>'``. Every text piece is appended to
    the conditional cache and every ``'<image>'`` marker triggers one
    flow-matching sampling run. Three caches are maintained for
    classifier-free guidance: fully conditional, text-unconditional (input
    images only) and image-unconditional (empty prompt). After each image,
    either the matching entry of ``gt_images`` or the sampled image is
    encoded and appended to the first two caches.

    Args:
        tokenizer: tokenizer used for prompts and special-token lookup.
        prompt: user prompt. ``'<image>'`` placeholders are bound to
            ``images``; missing placeholders are prepended.
        gt_text: answer text containing ``'<image>'`` markers.
        images: optional input images handed to ``load_image_native``.
        gt_images: optional images fed back into the context in place of
            the sampled ones, matched by generation order.
        cfg_scale: guidance weight of the text condition.
        img_cfg_scale: guidance weight of the image condition.
        cfg_norm: ``'global'`` or ``'channel'`` rescales the guided velocity
            so its norm does not exceed the conditional one; any other
            value disables rescaling.
        max_images: maximum number of images to generate.
        enable_timestep_shift: whether to warp the timesteps with
            ``_apply_time_schedule``.
        timestep_shift: shift factor passed to ``_apply_time_schedule``.
        image_size: ``(width, height)`` tuple, or a list of such tuples (one
            per image, padded with its last entry). The caller's list is
            never modified.
        num_steps: number of Euler steps per image.
        IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: image marker tokens.
        method: accepted but not read; sampling always uses Euler steps.
        cfg_interval: ``(low, high)`` timestep window in which guidance is
            applied; guidance is always on when ``low == 0``.
        t_eps: written to ``self.config.t_eps`` (clamp used when converting
            a prediction into a velocity).
        verbose: echo the processed text and show a progress bar when
            ``tqdm`` is importable.
        system_message: system message of the conditional prompt.

    Returns:
        List with one image tensor of shape (1, 3, height, width) per
        generated image.
    """
    self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.img_start_token_id = tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
    self.config.t_eps = t_eps

    if isinstance(image_size, tuple):
        image_size_list = [image_size] * max_images
    elif isinstance(image_size, list) and isinstance(image_size[0], tuple):
        # Work on a copy so the padding below never mutates the caller's list.
        image_size_list = list(image_size)
        if len(image_size_list) < max_images:
            image_size_list += [image_size_list[-1]] * (max_images - len(image_size_list))
    else:
        assert False, "image size should be a tuple or a list of tuple"

    if images is None:
        images = []

    image_token_count = prompt.count('<image>')
    assert len(images) >= image_token_count
    if len(images) > image_token_count:
        # Give every image without an explicit placeholder one at the front.
        prompt = "<image>\n" * (len(images) - image_token_count) + prompt

    pixel_values = []
    grid_hw = []
    for image in images:
        cur_pixel_values, cur_grid_hw = load_image_native(image, self.patch_size, self.downsample_ratio, min_pixels=512*512, max_pixels=min(2048*2048, (4096*4096)//max(1, len(images))), upscale=False)
        grid_hw.append(cur_grid_hw.to(self.device))
        pixel_values.append(cur_pixel_values.to(self.device).to(torch.bfloat16))

    merge_size = int(1 / self.downsample_ratio)
    patch_cell = self.patch_size * merge_size
    pv_tensor = torch.cat(pixel_values) if pixel_values else None
    ghw_tensor = torch.cat(grid_hw) if grid_hw else None

    def replace_image_tokens(query, grid_hw_list):
        # Expand each '<image>' placeholder into its start/context/end tokens.
        for cur_grid in grid_hw_list:
            num_patch_token = int(cur_grid[0, 0] * cur_grid[0, 1] * self.downsample_ratio**2)
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_patch_token + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
        return query

    # Conditional cache: full prompt with an empty think block.
    template_cond = get_conv_template(self.template)
    template_cond.system_message = system_message
    template_cond.append_message(template_cond.roles[0], prompt)
    template_cond.append_message(template_cond.roles[1], None)
    query_cond = template_cond.get_prompt() + '<think>\n\n</think>\n\n'
    query_cond = replace_image_tokens(query_cond, grid_hw)
    input_embeds_cond, indexes_cond, attention_mask_cond = self._build_it2i_inputs(tokenizer, query_cond, pv_tensor, ghw_tensor)
    outputs_cond = self.language_model(inputs_embeds=input_embeds_cond, indexes=indexes_cond, attention_mask=attention_mask_cond, use_cache=True)
    past_key_values_cond = outputs_cond.past_key_values
    t_index_cond = indexes_cond[0].max().item()

    # Text-unconditional cache: the input images without any prompt text.
    question_text_uncondition = '<image>' * len(images)
    template_tu = get_conv_template(self.template)
    template_tu.system_message = self.system_message
    template_tu.append_message(template_tu.roles[0], question_text_uncondition)
    template_tu.append_message(template_tu.roles[1], None)
    query_text_uncond = template_tu.get_prompt()
    query_text_uncond = replace_image_tokens(query_text_uncond, grid_hw)
    input_embeds_tu, indexes_tu, attention_mask_tu = self._build_it2i_inputs(tokenizer, query_text_uncond, pv_tensor, ghw_tensor)
    outputs_tu = self.language_model(inputs_embeds=input_embeds_tu, indexes=indexes_tu, attention_mask=attention_mask_tu, use_cache=True)
    past_key_values_tu = outputs_tu.past_key_values
    t_index_tu = indexes_tu[0].max().item()

    # Image-unconditional cache: empty prompt followed by the image start token.
    query_img_uncond = self._build_t2i_query("", append_text=IMG_START_TOKEN)
    input_embeds_iu, indexes_iu, attention_mask_iu = self._build_it2i_inputs(tokenizer, query_img_uncond)
    outputs_iu = self.language_model(inputs_embeds=input_embeds_iu, indexes=indexes_iu, attention_mask=attention_mask_iu, use_cache=True)
    past_key_values_iu = outputs_iu.past_key_values

    generated_images = []
    img_count = 0
    device = self.device

    def append_image_to_cache(cache, t_idx, inputs_embeds_img, N_img_tokens, abs_pos_w, abs_pos_h):
        # Append N image tokens plus the trailing image end token to ``cache``.
        past_len = cache.get_seq_length()
        tgt_len = N_img_tokens + 1
        t_indexes = torch.zeros(tgt_len, dtype=torch.long, device=device)
        t_indexes[:N_img_tokens] = t_idx + 1
        t_indexes[N_img_tokens] = t_idx + 2
        h_indexes = torch.zeros(tgt_len, dtype=torch.long, device=device)
        w_indexes = torch.zeros(tgt_len, dtype=torch.long, device=device)
        h_indexes[:N_img_tokens] = abs_pos_h
        w_indexes[:N_img_tokens] = abs_pos_w
        indexes = torch.stack([t_indexes, h_indexes, w_indexes], dim=0)
        mask = torch.zeros(1, 1, tgt_len, past_len + tgt_len, device=device)
        # Image tokens attend to each other freely but not to the end token.
        mask[0, 0, :N_img_tokens, past_len + N_img_tokens] = float('-inf')
        attention_mask_dict = {"full_attention": mask}
        self.language_model(
            inputs_embeds=inputs_embeds_img,
            indexes=indexes,
            attention_mask=attention_mask_dict,
            past_key_values=cache,
            use_cache=True
        )
        return t_idx + 2

    parts = gt_text.split('<image>')
    img_start_id_tensor = torch.tensor([[self.img_start_token_id]], device=device)

    for i, part in enumerate(parts):
        if len(part) > 0:
            if verbose:
                print(part, end='', flush=True)
            part_ids = tokenizer(part, return_tensors='pt', add_special_tokens=False)['input_ids'].to(device)
            # Only the conditional cache sees the answer text.
            t_index_cond = self._append_text_tokens_to_cache(past_key_values_cond, t_index_cond, part_ids)

        if i < len(parts) - 1:
            if img_count >= max_images:
                break
            if verbose:
                print("<image>", end='', flush=True)

            t_index_cond = self._append_text_tokens_to_cache(past_key_values_cond, t_index_cond, img_start_id_tensor)
            t_index_tu = self._append_text_tokens_to_cache(past_key_values_tu, t_index_tu, img_start_id_tensor)

            cur_image_size = image_size_list[img_count]
            token_h = cur_image_size[1] // patch_cell
            token_w = cur_image_size[0] // patch_cell
            num_tokens = token_h * token_w

            indexes_image_condition = self._build_t2i_image_indexes(token_h, token_w, t_index_cond + 1, device=device)
            indexes_image_text_uncondition = self._build_t2i_image_indexes(token_h, token_w, t_index_tu + 1, device=device)
            indexes_image_img_uncondition = self._build_t2i_image_indexes(token_h, token_w, indexes_iu[0].max() + 1, device=device)

            grid_h = cur_image_size[1] // self.patch_size
            grid_w = cur_image_size[0] // self.patch_size
            gen_grid_hw = torch.tensor([[grid_h, grid_w]], device=device)

            noise_scale = self.noise_scale
            if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'):
                # Scale the noise with the square root of the relative token count.
                base = float(self.noise_scale_base_image_seq_len)
                noise_scale = math.sqrt((grid_h*grid_w)/(merge_size**2)/base) * float(self.noise_scale)
                if self.noise_scale_mode == 'dynamic_sqrt':
                    noise_scale = math.sqrt(noise_scale)
            # NOTE(review): the clamp is applied in every mode here; confirm it
            # was not meant only for the resolution-dependent modes above.
            noise_scale = min(noise_scale, self.noise_scale_max_value)

            image_prediction = noise_scale * torch.randn((1, 3, cur_image_size[1], cur_image_size[0]), device=device, dtype=outputs_cond.logits.dtype)

            # Aliases (not copies) of the three caches used during sampling.
            past_key_values_cond_cfg = past_key_values_cond
            past_key_values_tu_cfg = past_key_values_tu
            past_key_values_iu_cfg = past_key_values_iu

            attention_mask_condition = {"full_attention": None}
            attention_mask_text_uncondition = {"full_attention": None}
            attention_mask_img_uncondition = {"full_attention": None}

            for cfg_cache in (past_key_values_cond_cfg, past_key_values_tu_cfg, past_key_values_iu_cfg):
                prepare_flash_kv_cache(
                    cfg_cache,
                    current_len=num_tokens,
                    batch_size=1,
                )

            timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device)
            if enable_timestep_shift:
                timesteps = self._apply_time_schedule(timesteps, num_tokens, timestep_shift)

            step_iter = range(num_steps)
            if verbose:
                try:
                    from tqdm import tqdm as _tqdm
                    step_iter = _tqdm(
                        step_iter,
                        # Label with the size of the image being generated.
                        desc=f"image {img_count + 1} ({cur_image_size[0]}x{cur_image_size[1]})",
                        total=num_steps,
                        leave=False,
                    )
                except ImportError:
                    pass

            for step_i in step_iter:
                t = timesteps[step_i]
                t_next = timesteps[step_i + 1]

                z = self.patchify(image_prediction, patch_cell)
                image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
                image_embeds = self.extract_feature(image_input.view(1 * grid_h*grid_w, -1), gen_model=True, grid_hw=gen_grid_hw).view(1, num_tokens, -1)

                t_expanded = t.expand(num_tokens)
                timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(1, num_tokens, -1)
                if self.add_noise_scale_embedding:
                    noise_scale_tensor = torch.full_like(t_expanded, noise_scale/self.noise_scale_max_value)
                    noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(1, num_tokens, -1)
                    timestep_embeddings += noise_embeddings
                image_embeds = image_embeds + timestep_embeddings

                use_cfg = (t > cfg_interval[0] and t < cfg_interval[1]) or cfg_interval[0] == 0

                # Arguments shared by every velocity prediction of this step.
                # ``image_size`` is required by the pixel head.
                shared_kwargs = dict(
                    t=t,
                    z=z,
                    image_token_num=num_tokens,
                    timestep_embeddings=timestep_embeddings,
                    image_size=cur_image_size,
                )

                out_cond = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_cond_cfg, **shared_kwargs)
                if not use_cfg:
                    v_pred = out_cond
                elif cfg_scale == 1 and img_cfg_scale == 1:
                    v_pred = out_cond
                elif img_cfg_scale == 1:
                    # Text guidance only: contrast against the text-unconditional branch.
                    out_img_cond = self._t2i_predict_v(image_embeds, indexes_image_text_uncondition, attention_mask_text_uncondition, past_key_values_tu_cfg, **shared_kwargs)
                    v_pred = out_img_cond + cfg_scale * (out_cond - out_img_cond)
                elif cfg_scale == img_cfg_scale:
                    # Equal weights collapse to one contrast against the empty prompt.
                    out_uncond = self._t2i_predict_v(image_embeds, indexes_image_img_uncondition, attention_mask_img_uncondition, past_key_values_iu_cfg, **shared_kwargs)
                    v_pred = out_uncond + cfg_scale * (out_cond - out_uncond)
                else:
                    out_img_cond = self._t2i_predict_v(image_embeds, indexes_image_text_uncondition, attention_mask_text_uncondition, past_key_values_tu_cfg, **shared_kwargs)
                    out_uncond = self._t2i_predict_v(image_embeds, indexes_image_img_uncondition, attention_mask_img_uncondition, past_key_values_iu_cfg, **shared_kwargs)
                    v_pred = (
                        out_uncond
                        + cfg_scale * (out_cond - out_img_cond)
                        + img_cfg_scale * (out_img_cond - out_uncond)
                    )

                if (cfg_scale > 1 or img_cfg_scale > 1) and use_cfg:
                    # Shrink the guided velocity so its norm never exceeds the
                    # conditional one (whole sample or per token).
                    if cfg_norm == 'global':
                        norm_v_condition = torch.norm(out_cond, dim=(1,2), keepdim=True)
                        norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True)
                        scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
                        v_pred = v_pred * scale
                    elif cfg_norm == 'channel':
                        norm_v_condition = torch.norm(out_cond, dim=-1, keepdim=True)
                        norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True)
                        scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
                        v_pred = v_pred * scale

                z = self._euler_step(v_pred, z, t, t_next)
                image_prediction = self.unpatchify(z, patch_cell, cur_image_size[1], cur_image_size[0])

            generated_images.append(image_prediction)

            clear_flash_kv_cache(past_key_values_cond_cfg)
            clear_flash_kv_cache(past_key_values_tu_cfg)
            clear_flash_kv_cache(past_key_values_iu_cfg)

            if gt_images is not None and img_count < len(gt_images):
                # Feed the provided reference image back instead of the sample.
                gt_img_pil = gt_images[img_count]
                gt_pixel_values, gt_grid_hw = load_image_native(gt_img_pil, self.patch_size, self.downsample_ratio, min_pixels=512*512, max_pixels=(2048*2048), upscale=False)
                gt_pixel_values = gt_pixel_values.to(device).to(torch.bfloat16)
                flatten_pixel_values = gt_pixel_values
                gen_grid_hw_und = gt_grid_hw
            else:
                # Map the sample from [-1, 1] to [0, 1], then apply the mean/std
                # normalization before encoding it with the understanding encoder.
                pred_img = image_prediction[0].unsqueeze(0).to(torch.bfloat16)
                raw_img = pred_img * 0.5 + 0.5
                img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=raw_img.dtype, device=device).view(1, 3, 1, 1)
                img_std = torch.tensor([0.229, 0.224, 0.225], dtype=raw_img.dtype, device=device).view(1, 3, 1, 1)
                und_img = (raw_img - img_mean) / img_std
                c, h, w = und_img[0].shape
                ps = self.patch_size
                p_grid_h = h // ps
                p_grid_w = w // ps
                flatten_pixel_values = (
                    und_img[0].view(c, p_grid_h, ps, p_grid_w, ps)
                    .permute(1, 3, 0, 2, 4)
                    .reshape(p_grid_h * p_grid_w, c * ps ** 2)
                )
                gen_grid_hw_und = torch.tensor([[p_grid_h, p_grid_w]], device=device)

            vit_embeds = self.extract_feature(flatten_pixel_values, grid_hw=gen_grid_hw_und[:1]).unsqueeze(0)
            img_end_id = tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
            img_end_embed = self.language_model.get_input_embeddings()(torch.tensor([[img_end_id]], device=device))
            inputs_embeds_img = torch.cat([vit_embeds, img_end_embed], dim=1)  # (1, N + 1, C)
            N_img_tokens = vit_embeds.shape[1]
            abs_pos_w, abs_pos_h = build_abs_positions_from_grid_hw(gen_grid_hw_und[:1] // int(1 / self.downsample_ratio), device=device)

            t_index_cond = append_image_to_cache(past_key_values_cond, t_index_cond, inputs_embeds_img, N_img_tokens, abs_pos_w, abs_pos_h)
            t_index_tu = append_image_to_cache(past_key_values_tu, t_index_tu, inputs_embeds_img, N_img_tokens, abs_pos_w, abs_pos_h)
            img_count += 1

    return generated_images
@torch.no_grad()
def interleave_gen(
    self,
    tokenizer,
    prompt,
    images=None,
    generation_config=None,
    cfg_scale=1.0,
    img_cfg_scale=1.0,
    cfg_norm='none',
    max_images=10,
    enable_timestep_shift=True,
    timestep_shift=1.0,
    image_size=(256, 256),
    num_steps=30,
    IMG_START_TOKEN='<img>',
    IMG_END_TOKEN='</img>',
    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
    method='euler',
    cfg_interval=(0, 1),
    t_eps=0.02,
    verbose=False,
    system_message='',
    think_mode=False,
    seed=0,
):
    """Generate an interleaved text / image answer for ``prompt``.

    Text tokens are decoded greedily (arg-max) from the conditional branch.
    Each time the model emits the image-start token, an image is produced by
    ``num_steps`` updates of the form ``z += (t_next - t) * v_pred``; the
    result is then re-encoded and appended to the conditional and
    text-unconditional KV caches so that decoding can continue after it.

    Three prefix caches are built up front and mixed for guidance:
    conditional (prompt + input images), text-unconditional (input images
    only) and fully unconditional (empty prompt).

    Args:
        tokenizer: Object providing ``convert_tokens_to_ids`` and ``decode``.
        prompt: User text. ``<image>`` placeholders are matched to ``images``
            in order; placeholders missing for trailing images are prepended.
        images: Optional list of inputs accepted by ``load_image_native``.
        generation_config: Only ``max_new_tokens`` is read (default 8192).
        cfg_scale: Text guidance weight.
        img_cfg_scale: Image guidance weight.
        cfg_norm: ``'global'`` or ``'channel'`` rescale the guided velocity so
            its norm never exceeds the conditional one; any other value
            leaves it untouched.
        max_images: Maximum number of images to generate.
        enable_timestep_shift: Whether to pass the timesteps through
            ``self._apply_time_schedule``.
        timestep_shift: Shift forwarded to ``self._apply_time_schedule``.
        image_size: ``(width, height)`` tuple used for every image, or a list
            of such tuples (one per generated image). A short list is padded
            with its last entry; the caller's list is never modified.
        num_steps: Number of integration steps per image.
        IMG_START_TOKEN: Token that opens an image span.
        IMG_END_TOKEN: Token that closes an image span.
        IMG_CONTEXT_TOKEN: Placeholder token repeated once per image patch.
        method: Not used by this method.
        cfg_interval: ``(low, high)`` timestep window for guidance.
        t_eps: Stored on ``self.config.t_eps``.
        verbose: Stream decoded text and show a progress bar (when ``tqdm``
            is importable) during image generation.
        system_message: System message for the conditional prompt.
        think_mode: When False an empty ``<think>`` block is appended to the
            conditional prompt.
        seed: Seed for the noise generator.

    Returns:
        ``(generated_text, generated_images)``. ``generated_text`` contains an
        ``<image>`` marker at the position of every generated image;
        ``generated_images`` holds one image tensor per marker.

    Raises:
        AssertionError: If ``image_size`` is neither a tuple nor a non-empty
            list of tuples, or if ``prompt`` has more ``<image>``
            placeholders than there are ``images``.
    """
    self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.img_start_token_id = tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
    self.config.t_eps = t_eps

    # Resolve one (width, height) entry per image that may be generated.
    if isinstance(image_size, tuple):
        image_size_list = [image_size] * max_images
    elif isinstance(image_size, list) and image_size and isinstance(image_size[0], tuple):
        # Copy first: padding must not leak into the caller's list.
        image_size_list = list(image_size)
        if len(image_size_list) < max_images:
            image_size_list += [image_size_list[-1]] * (max_images - len(image_size_list))
    else:
        # ``raise`` instead of ``assert False`` so the check survives ``python -O``.
        raise AssertionError("image size should be a tuple or a list of tuple")

    if (
        generation_config
        and hasattr(generation_config, 'max_new_tokens')
        and generation_config.max_new_tokens is not None
    ):
        max_new_tokens = generation_config.max_new_tokens
    else:
        max_new_tokens = 8192
    current_generated_tokens = 0

    if images is None:
        images = []

    # This template instance is only used to look up the end-of-turn token.
    template = get_conv_template(self.template)
    template.system_message = self.system_message
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())

    image_token_count = prompt.count('<image>')
    assert len(images) >= image_token_count
    if len(images) > image_token_count:
        prompt = "<image>\n" * (len(images) - image_token_count) + prompt

    # The per-image pixel budget shrinks as more input images are supplied.
    max_pixels = min(2048 * 2048, (4096 * 4096) // max(1, len(images)))
    pixel_values = []
    grid_hw = []
    for image in images:
        cur_pixel_values, cur_grid_hw = load_image_native(
            image,
            self.patch_size,
            self.downsample_ratio,
            min_pixels=512 * 512,
            max_pixels=max_pixels,
            upscale=False,
        )
        grid_hw.append(cur_grid_hw.to(self.device))
        pixel_values.append(cur_pixel_values.to(self.device).to(torch.bfloat16))
    merge_size = int(1 / self.downsample_ratio)
    pv_tensor = torch.cat(pixel_values) if pixel_values else None
    ghw_tensor = torch.cat(grid_hw) if grid_hw else None

    def replace_image_tokens(query, grid_hw_list):
        # Expand each '<image>' placeholder, in order, into the token span
        # sized for the corresponding input image.
        for cur_grid_hw in grid_hw_list:
            num_patch_token = int(cur_grid_hw[0, 0] * cur_grid_hw[0, 1] * self.downsample_ratio**2)
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_patch_token + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)
        return query

    # Conditional cache: full prompt plus input images.
    template_cond = get_conv_template(self.template)
    template_cond.system_message = system_message
    template_cond.append_message(template_cond.roles[0], prompt)
    template_cond.append_message(template_cond.roles[1], None)
    query_cond = template_cond.get_prompt()
    if not think_mode:
        query_cond = query_cond + '<think>\n\n</think>\n\n'
    query_cond = replace_image_tokens(query_cond, grid_hw)
    input_embeds_cond, indexes_cond, attention_mask_cond = self._build_it2i_inputs(
        tokenizer, query_cond, pv_tensor, ghw_tensor
    )
    outputs_cond = self.language_model(
        inputs_embeds=input_embeds_cond,
        indexes=indexes_cond,
        attention_mask=attention_mask_cond,
        use_cache=True,
    )
    past_key_values_cond = outputs_cond.past_key_values
    t_index_cond = indexes_cond[0].max().item()

    # Text-unconditional cache: input images only, no user text.
    template_tu = get_conv_template(self.template)
    template_tu.system_message = self.system_message
    template_tu.append_message(template_tu.roles[0], '<image>' * len(images))
    template_tu.append_message(template_tu.roles[1], None)
    query_text_uncond = replace_image_tokens(template_tu.get_prompt(), grid_hw)
    input_embeds_tu, indexes_tu, attention_mask_tu = self._build_it2i_inputs(
        tokenizer, query_text_uncond, pv_tensor, ghw_tensor
    )
    outputs_tu = self.language_model(
        inputs_embeds=input_embeds_tu,
        indexes=indexes_tu,
        attention_mask=attention_mask_tu,
        use_cache=True,
    )
    past_key_values_tu = outputs_tu.past_key_values
    t_index_tu = indexes_tu[0].max().item()

    # Fully unconditional cache: empty prompt, no images.
    query_img_uncond = self._build_t2i_query("", append_text=IMG_START_TOKEN)
    input_embeds_iu, indexes_iu, attention_mask_iu = self._build_it2i_inputs(tokenizer, query_img_uncond)
    outputs_iu = self.language_model(
        inputs_embeds=input_embeds_iu,
        indexes=indexes_iu,
        attention_mask=attention_mask_iu,
        use_cache=True,
    )
    past_key_values_iu = outputs_iu.past_key_values

    generated_text = ""
    generated_images = []
    img_count = 0
    device = self.device
    next_token = torch.argmax(outputs_cond.logits[:, -1, :], dim=-1)
    generator = torch.Generator(self.device).manual_seed(seed)

    while True:
        # ---- Text phase: greedy decoding until EOS, image start or budget.
        gen_tokens = []
        hit_max_tokens = False
        last_decoded = 0
        while True:
            token_item = next_token.item()
            if token_item == eos_token_id or token_item == self.img_start_token_id:
                break
            gen_tokens.append(token_item)
            current_generated_tokens += 1
            self.language_model.model.current_index = t_index_cond
            outputs_cond = self.language_model(
                input_ids=next_token.unsqueeze(0),
                past_key_values=past_key_values_cond,
                use_cache=True,
            )
            past_key_values_cond = outputs_cond.past_key_values
            t_index_cond += 1
            next_token = torch.argmax(outputs_cond.logits[:, -1, :], dim=-1)
            # Stream partial text so users see liveness during long runs
            # (e.g. low VRAM offload). Decode in 16-token chunks.
            if verbose and len(gen_tokens) - last_decoded >= 16:
                partial = tokenizer.decode(gen_tokens[last_decoded:], skip_special_tokens=True)
                print(partial, end='', flush=True)
                last_decoded = len(gen_tokens)
            if current_generated_tokens >= max_new_tokens:
                hit_max_tokens = True
                break

        if gen_tokens:
            generated_text += tokenizer.decode(gen_tokens, skip_special_tokens=True)
            if verbose:
                remaining = tokenizer.decode(gen_tokens[last_decoded:], skip_special_tokens=True)
                if remaining:
                    print(remaining, end='', flush=True)

        if next_token.item() == eos_token_id or hit_max_tokens:
            break

        if next_token.item() == self.img_start_token_id:
            # Honour the caller-supplied cap (it also sizes image_size_list).
            if img_count >= max_images:
                break
            generated_text += "<image>"
            if verbose:
                print(f"\n[image {img_count + 1}] preparing diffusion...", flush=True)

            # Feed the image-start token to the conditional and
            # text-unconditional branches.
            self.language_model.model.current_index = t_index_cond
            outputs_cond = self.language_model(
                input_ids=next_token.unsqueeze(0), past_key_values=past_key_values_cond, use_cache=True
            )
            past_key_values_cond = outputs_cond.past_key_values
            t_index_cond += 1
            self.language_model.model.current_index = t_index_tu
            outputs_tu = self.language_model(
                input_ids=next_token.unsqueeze(0), past_key_values=past_key_values_tu, use_cache=True
            )
            past_key_values_tu = outputs_tu.past_key_values
            t_index_tu += 1

            # ---- Image phase.
            gen_size = image_size_list[img_count]  # (width, height)
            token_h = gen_size[1] // (self.patch_size * merge_size)
            token_w = gen_size[0] // (self.patch_size * merge_size)
            num_img_tokens = token_h * token_w
            indexes_image_condition = self._build_t2i_image_indexes(
                token_h, token_w, t_index_cond + 1, device=device
            )
            indexes_image_text_uncondition = self._build_t2i_image_indexes(
                token_h, token_w, t_index_tu + 1, device=device
            )
            indexes_image_img_uncondition = self._build_t2i_image_indexes(
                token_h, token_w, indexes_iu[0].max() + 1, device=device
            )
            grid_h = gen_size[1] // self.patch_size
            grid_w = gen_size[0] // self.patch_size
            gen_grid_hw = torch.tensor([[grid_h, grid_w]], device=device)

            noise_scale = self.noise_scale
            if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'):
                base = float(self.noise_scale_base_image_seq_len)
                noise_scale = math.sqrt((grid_h * grid_w) / (merge_size**2) / base) * float(self.noise_scale)
                if self.noise_scale_mode == 'dynamic_sqrt':
                    noise_scale = math.sqrt(noise_scale)
                # NOTE(review): the cap is applied to the resolution-scaled
                # value only - confirm a static noise_scale stays uncapped.
                noise_scale = min(noise_scale, self.noise_scale_max_value)
            image_prediction = noise_scale * torch.randn(
                (1, 3, gen_size[1], gen_size[0]),
                device=device,
                dtype=outputs_cond.logits.dtype,
                generator=generator,
            )

            attention_mask_condition = {"full_attention": None}
            attention_mask_text_uncondition = {"full_attention": None}
            attention_mask_img_uncondition = {"full_attention": None}
            for cache in (past_key_values_cond, past_key_values_tu, past_key_values_iu):
                prepare_flash_kv_cache(cache, current_len=num_img_tokens, batch_size=1)

            timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
            if enable_timestep_shift:
                timesteps = self._apply_time_schedule(timesteps, num_img_tokens, timestep_shift)

            step_iter = range(num_steps)
            if verbose:
                # tqdm is optional: fall back to a plain range when missing.
                try:
                    from tqdm import tqdm as _tqdm
                    step_iter = _tqdm(
                        step_iter,
                        desc=f"image {img_count + 1} ({gen_size[0]}x{gen_size[1]})",
                        total=num_steps,
                        leave=False,
                    )
                except ImportError:
                    pass

            for step_i in step_iter:
                t = timesteps[step_i]
                t_next = timesteps[step_i + 1]
                z = self.patchify(image_prediction, self.patch_size * merge_size)
                image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
                image_embeds = self.extract_feature(
                    image_input.view(1 * grid_h * grid_w, -1), gen_model=True, grid_hw=gen_grid_hw
                ).view(1, num_img_tokens, -1)
                t_expanded = t.expand(num_img_tokens)
                timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(1, num_img_tokens, -1)
                if self.add_noise_scale_embedding:
                    noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value)
                    noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(
                        1, num_img_tokens, -1
                    )
                    timestep_embeddings += noise_embeddings
                image_embeds = image_embeds + timestep_embeddings

                # NOTE(review): when cfg_interval[0] == 0 this is always True,
                # i.e. the upper bound is ignored - confirm that is intended.
                use_cfg = (t > cfg_interval[0] and t < cfg_interval[1]) or cfg_interval[0] == 0

                def predict_v(image_indexes, attn_mask, cache):
                    # One velocity prediction for the current step on the given branch.
                    return self._t2i_predict_v(
                        image_embeds,
                        image_indexes,
                        attn_mask,
                        cache,
                        t,
                        z,
                        image_token_num=num_img_tokens,
                        timestep_embeddings=timestep_embeddings,
                    )

                out_cond = predict_v(indexes_image_condition, attention_mask_condition, past_key_values_cond)
                if not use_cfg:
                    v_pred = out_cond
                elif cfg_scale == 1 and img_cfg_scale == 1:
                    v_pred = out_cond
                elif img_cfg_scale == 1:
                    out_img_cond = predict_v(
                        indexes_image_text_uncondition, attention_mask_text_uncondition, past_key_values_tu
                    )
                    v_pred = out_img_cond + cfg_scale * (out_cond - out_img_cond)
                elif cfg_scale == img_cfg_scale:
                    out_uncond = predict_v(
                        indexes_image_img_uncondition, attention_mask_img_uncondition, past_key_values_iu
                    )
                    v_pred = out_uncond + cfg_scale * (out_cond - out_uncond)
                else:
                    out_img_cond = predict_v(
                        indexes_image_text_uncondition, attention_mask_text_uncondition, past_key_values_tu
                    )
                    out_uncond = predict_v(
                        indexes_image_img_uncondition, attention_mask_img_uncondition, past_key_values_iu
                    )
                    v_pred = (
                        out_uncond
                        + cfg_scale * (out_cond - out_img_cond)
                        + img_cfg_scale * (out_img_cond - out_uncond)
                    )

                # Parenthesised so that ``use_cfg`` gates both scales (the
                # unparenthesised form bound ``and`` tighter than ``or``).
                if (cfg_scale > 1 or img_cfg_scale > 1) and use_cfg:
                    if cfg_norm == 'global':
                        norm_v_condition = torch.norm(out_cond, dim=(1, 2), keepdim=True)
                        norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True)
                        scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
                        v_pred = v_pred * scale
                    elif cfg_norm == 'channel':
                        norm_v_condition = torch.norm(out_cond, dim=-1, keepdim=True)
                        norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True)
                        scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
                        v_pred = v_pred * scale

                z = z + (t_next - t) * v_pred
                image_prediction = self.unpatchify(z, self.patch_size * merge_size, gen_size[1], gen_size[0])

            generated_images.append(image_prediction)
            for cache in (past_key_values_cond, past_key_values_tu, past_key_values_iu):
                clear_flash_kv_cache(cache)
            img_count += 1

            # Re-encode the generated image with the understanding branch.
            pred_img = image_prediction[0].unsqueeze(0).to(torch.bfloat16)
            # Map back to [0, 1] (assumes the generator works in [-1, 1] -
            # TODO confirm), then apply the mean/std normalisation below.
            raw_img = pred_img * 0.5 + 0.5
            img_mean = torch.tensor([0.485, 0.456, 0.406], dtype=raw_img.dtype, device=device).view(1, 3, 1, 1)
            img_std = torch.tensor([0.229, 0.224, 0.225], dtype=raw_img.dtype, device=device).view(1, 3, 1, 1)
            und_img = (raw_img - img_mean) / img_std
            c, h, w = und_img[0].shape
            ps = self.patch_size
            p_grid_h = h // ps
            p_grid_w = w // ps
            flatten_pixel_values = (
                und_img[0].view(c, p_grid_h, ps, p_grid_w, ps)
                .permute(1, 3, 0, 2, 4)  # [grid_h, grid_w, c, patch_size, patch_size]
                .reshape(p_grid_h * p_grid_w, c * ps ** 2)
            )
            vit_embeds = self.extract_feature(flatten_pixel_values, grid_hw=gen_grid_hw[:1]).unsqueeze(0)
            img_end_id = tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
            img_end_embed = self.language_model.get_input_embeddings()(torch.tensor([[img_end_id]], device=device))
            inputs_embeds_img = torch.cat([vit_embeds, img_end_embed], dim=1)  # (1, N + 1, C)
            n_vit_tokens = vit_embeds.shape[1]
            abs_pos_w, abs_pos_h = build_abs_positions_from_grid_hw(
                gen_grid_hw[:1] // int(1 / self.downsample_ratio), device=device
            )

            def append_image_to_cache(cache, t_idx):
                # Push the re-encoded image tokens plus the image-end token
                # into ``cache``; returns the outputs and the new time index.
                past_len = cache.get_seq_length()
                tgt_len = n_vit_tokens + 1
                t_indexes = torch.zeros(tgt_len, dtype=torch.long, device=device)
                t_indexes[:n_vit_tokens] = t_idx + 1  # all image tokens share one time index
                t_indexes[n_vit_tokens] = t_idx + 2  # the image-end token takes the next one
                h_indexes = torch.zeros(tgt_len, dtype=torch.long, device=device)
                w_indexes = torch.zeros(tgt_len, dtype=torch.long, device=device)
                h_indexes[:n_vit_tokens] = abs_pos_h
                w_indexes[:n_vit_tokens] = abs_pos_w
                indexes = torch.stack([t_indexes, h_indexes, w_indexes], dim=0)
                mask = torch.zeros(1, 1, tgt_len, past_len + tgt_len, device=device)
                # Image tokens must not attend to the trailing image-end token.
                mask[0, 0, :n_vit_tokens, past_len + n_vit_tokens] = float('-inf')
                outputs = self.language_model(
                    inputs_embeds=inputs_embeds_img,
                    indexes=indexes,
                    attention_mask={"full_attention": mask},
                    past_key_values=cache,
                    use_cache=True,
                )
                return outputs, t_idx + 2

            outputs_cond, t_index_cond = append_image_to_cache(past_key_values_cond, t_index_cond)
            outputs_tu, t_index_tu = append_image_to_cache(past_key_values_tu, t_index_tu)
            next_token = torch.argmax(outputs_cond.logits[:, -1, :], dim=-1)

    return generated_text, generated_images
@torch.no_grad()
def it2i_generate(
    self,
    tokenizer,
    prompt,
    images,
    cfg_scale=1,
    img_cfg_scale=1,
    cfg_norm='none',
    enable_timestep_shift=True,
    timestep_shift=1,
    image_size=(256, 256),
    num_steps=30,
    IMG_START_TOKEN='<img>',
    IMG_END_TOKEN='</img>',
    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
    method='euler',
    cfg_interval=(0, 1),
    batch_size=1,
    t_eps=0.02,
    think_mode=False,
    seed=0,
):
    """Generate ``batch_size`` images conditioned on ``prompt`` and ``images``.

    Up to three prefix KV caches are built, depending on the guidance scales:
    conditional (prompt + images), image-conditional (images only) and
    unconditional (empty prompt). The image is then produced by ``num_steps``
    updates of the form ``z += (t_next - t) * v_pred``.

    Args:
        tokenizer: Tokenizer used to build the prompts.
        prompt: Instruction text. ``<image>`` placeholders are matched to
            ``images`` in order; missing placeholders are prepended (as
            ``Image-k:<image>`` lines when the prompt has none and several
            images are given).
        images: Reference images accepted by ``load_image_native``. An empty
            list is accepted; the prompts are then built without pixel
            values.
        cfg_scale: Text guidance weight.
        img_cfg_scale: Image guidance weight.
        cfg_norm: One of ``'none'``, ``'global'``, ``'channel'``.
        enable_timestep_shift: Whether to pass the timesteps through
            ``self._apply_time_schedule``.
        timestep_shift: Shift forwarded to ``self._apply_time_schedule``.
        image_size: ``(width, height)`` of the generated image.
        num_steps: Number of integration steps.
        IMG_START_TOKEN: Token that opens an image span.
        IMG_END_TOKEN: Token that closes an image span.
        IMG_CONTEXT_TOKEN: Placeholder token repeated once per image patch.
        method: Not used by this method.
        cfg_interval: ``(low, high)`` timestep window for guidance.
        batch_size: Number of images generated from the same prefix.
        t_eps: Stored on ``self.config.t_eps``.
        think_mode: When True the model first generates reasoning text via
            ``self._generate_think``.
        seed: Seed for the noise generator.

    Returns:
        The generated image tensor, or ``(image, think_text)`` when
        ``think_mode`` is True. The reasoning text is also stored on
        ``self.last_think_content``.

    Raises:
        AssertionError: If ``cfg_norm`` is not supported or ``prompt`` has
            more ``<image>`` placeholders than there are ``images``.
    """
    assert cfg_norm in ['none', 'global', 'channel']
    self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.config.t_eps = t_eps

    image_token_count = prompt.count('<image>')
    assert len(images) >= image_token_count
    if len(images) > image_token_count:
        if image_token_count == 0 and len(images) > 1:
            prompt = "".join(f"Image-{i + 1}:<image>\n" for i in range(len(images))) + prompt
        else:
            prompt = "<image>\n" * (len(images) - image_token_count) + prompt

    # The per-image pixel budget shrinks as more reference images are given.
    max_pixels = min(2048 * 2048, (4096 * 4096) // max(1, len(images)))
    pixel_values = []
    grid_hw = []
    for image in images:
        cur_pixel_values, cur_grid_hw = load_image_native(
            image,
            self.patch_size,
            self.downsample_ratio,
            min_pixels=512 * 512,
            max_pixels=max_pixels,
            upscale=False,
        )
        pixel_values.append(cur_pixel_values.to(self.device).to(torch.bfloat16))
        grid_hw.append(cur_grid_hw.to(self.device))
    # torch.cat rejects an empty list; fall back to None, the same way the
    # interleaved generator feeds _build_it2i_inputs when it has no images.
    pixel_values = torch.cat(pixel_values) if pixel_values else None
    grid_hw = torch.cat(grid_hw) if grid_hw else None

    merge_size = int(1 / self.downsample_ratio)
    question_condition = f"{prompt}"
    think_text = ""

    # Decide which guidance branches are actually needed.
    needs_cfg = not (cfg_scale == 1 and img_cfg_scale == 1)
    needs_img_condition = needs_cfg and (img_cfg_scale == 1 or cfg_scale != img_cfg_scale)
    needs_uncondition = needs_cfg and img_cfg_scale != 1

    if think_mode:
        # The reasoning text and the image-start token are generated later.
        think_content = '<think>\n'
    else:
        think_content = '<think>\n\n</think>\n\n' + IMG_START_TOKEN
    query_condition = self._build_t2i_query(
        question_condition, system_message=SYSTEM_MESSAGE_FOR_GEN, append_text=think_content
    )
    query_img_condition = (
        self._build_t2i_query('<image>' * len(images), append_text=IMG_START_TOKEN)
        if needs_img_condition
        else None
    )
    query_uncondition = self._build_t2i_query("", append_text=IMG_START_TOKEN) if needs_uncondition else None

    # Expand each '<image>' placeholder into the token span of its image.
    if grid_hw is not None:
        for cur_hw in grid_hw:
            num_patch_token = int(cur_hw[0] * cur_hw[1] * self.downsample_ratio**2)
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_patch_token + IMG_END_TOKEN
            query_condition = query_condition.replace('<image>', image_tokens, 1)
            if query_img_condition is not None:
                query_img_condition = query_img_condition.replace('<image>', image_tokens, 1)

    input_embeds_condition, indexes_condition, attention_mask_condition_prefix = self._build_it2i_inputs(
        tokenizer, query_condition, pixel_values, grid_hw
    )
    if query_img_condition is not None:
        input_embeds_img_condition, indexes_img_condition, attention_mask_img_condition_prefix = (
            self._build_it2i_inputs(tokenizer, query_img_condition, pixel_values, grid_hw)
        )
    else:
        input_embeds_img_condition = indexes_img_condition = attention_mask_img_condition_prefix = None
    if query_uncondition is not None:
        input_embeds_uncondition, indexes_uncondition, attention_mask_uncondition_prefix = (
            self._build_it2i_inputs(tokenizer, query_uncondition)
        )
    else:
        input_embeds_uncondition = indexes_uncondition = attention_mask_uncondition_prefix = None

    token_h = image_size[1] // (self.patch_size * merge_size)
    token_w = image_size[0] // (self.patch_size * merge_size)
    num_img_tokens = token_h * token_w

    # Position indexes of the image tokens continue after each prefix.
    indexes_image_condition = self._build_t2i_image_indexes(
        token_h, token_w, indexes_condition[0].max() + 1, device=input_embeds_condition.device
    )
    indexes_image_img_condition = (
        self._build_t2i_image_indexes(
            token_h, token_w, indexes_img_condition[0].max() + 1, device=input_embeds_img_condition.device
        )
        if indexes_img_condition is not None
        else None
    )
    indexes_image_uncondition = (
        self._build_t2i_image_indexes(
            token_h, token_w, indexes_uncondition[0].max() + 1, device=input_embeds_uncondition.device
        )
        if indexes_uncondition is not None
        else None
    )

    if think_mode:
        outputs_condition = self.language_model(
            inputs_embeds=input_embeds_condition,
            indexes=indexes_condition,
            attention_mask=attention_mask_condition_prefix,
            use_cache=True,
            output_hidden_states=True,
        )
        past_key_values_condition = outputs_condition.past_key_values
        hidden_states_condition = outputs_condition.hidden_states[-1]
        t_index_condition = indexes_condition[0].max().item()
        past_key_values_condition, t_index_condition, think_text = self._generate_think(
            tokenizer,
            outputs_condition,
            past_key_values_condition,
            t_index_condition,
            IMG_START_TOKEN,
        )
        # The prefix grew while thinking, so the image indexes are rebuilt.
        indexes_image_condition = self._build_t2i_image_indexes(
            token_h, token_w, t_index_condition + 1, device=input_embeds_condition.device
        )
    else:
        past_key_values_condition, hidden_states_condition = self._it2i_prefix_forward(
            input_embeds_condition, indexes_condition, attention_mask_condition_prefix
        )

    past_key_values_img_condition = None
    if input_embeds_img_condition is not None:
        past_key_values_img_condition, _ = self._it2i_prefix_forward(
            input_embeds_img_condition, indexes_img_condition, attention_mask_img_condition_prefix
        )
    past_key_values_uncondition = None
    if input_embeds_uncondition is not None:
        past_key_values_uncondition, _ = self._it2i_prefix_forward(
            input_embeds_uncondition, indexes_uncondition, attention_mask_uncondition_prefix
        )

    device = hidden_states_condition.device
    dtype = hidden_states_condition.dtype

    # Drop the prefix inputs before the sampling loop allocates its buffers.
    del pixel_values, grid_hw
    del input_embeds_condition, indexes_condition, attention_mask_condition_prefix
    del input_embeds_img_condition, indexes_img_condition, attention_mask_img_condition_prefix
    del input_embeds_uncondition, indexes_uncondition, attention_mask_uncondition_prefix
    del hidden_states_condition

    def expand_cache_batch(cache):
        # Broadcast the prefix keys/values along dim 0 to ``batch_size``.
        for layer in cache.layers:
            layer.keys = layer.keys.expand(batch_size, *layer.keys.shape[1:])
            layer.values = layer.values.expand(batch_size, *layer.values.shape[1:])

    active_caches = [
        cache
        for cache in (past_key_values_condition, past_key_values_img_condition, past_key_values_uncondition)
        if cache is not None
    ]
    for cache in active_caches:
        expand_cache_batch(cache)
    for cache in active_caches:
        prepare_flash_kv_cache(cache, current_len=num_img_tokens, batch_size=batch_size)

    grid_h = image_size[1] // self.patch_size
    grid_w = image_size[0] // self.patch_size
    gen_grid_hw = torch.tensor([[grid_h, grid_w]] * batch_size, device=device)

    noise_scale = self.noise_scale
    if self.noise_scale_mode in ("resolution", "dynamic", "dynamic_sqrt"):
        base = float(self.noise_scale_base_image_seq_len)
        scale = math.sqrt((grid_h * grid_w) / (merge_size**2) / base)
        noise_scale = scale * float(self.noise_scale)
        if self.noise_scale_mode == 'dynamic_sqrt':
            noise_scale = math.sqrt(noise_scale)
        # NOTE(review): the cap is applied to the resolution-scaled value
        # only - confirm a static noise_scale stays uncapped.
        noise_scale = min(noise_scale, self.noise_scale_max_value)

    generator = torch.Generator(device).manual_seed(seed)
    image_prediction = noise_scale * torch.randn(
        (batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype, generator=generator
    )

    attention_mask_condition = {"full_attention": None}
    attention_mask_img_condition = {"full_attention": None}
    attention_mask_uncondition = {"full_attention": None}

    timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
    if enable_timestep_shift:
        timesteps = self._apply_time_schedule(timesteps, num_img_tokens, timestep_shift)

    for step_i in range(num_steps):
        t = timesteps[step_i]
        t_next = timesteps[step_i + 1]
        # NOTE(review): when cfg_interval[0] == 0 this is always True, i.e.
        # the upper bound is ignored - confirm that is intended.
        use_cfg = (t > cfg_interval[0] and t < cfg_interval[1]) or cfg_interval[0] == 0

        z = self.patchify(image_prediction, self.patch_size * merge_size)
        image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
        image_embeds = self.extract_feature(
            image_input.view(batch_size * grid_h * grid_w, -1),
            gen_model=True,
            grid_hw=gen_grid_hw,
        ).view(batch_size, num_img_tokens, -1)
        t_expanded = t.expand(batch_size * num_img_tokens)
        timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(
            batch_size, num_img_tokens, -1
        )
        if self.add_noise_scale_embedding:
            noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value)
            noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(
                batch_size, num_img_tokens, -1
            )
            timestep_embeddings += noise_embeddings
        image_embeds = image_embeds + timestep_embeddings

        def predict_v(image_indexes, attn_mask, cache):
            # One velocity prediction for the current step on the given branch.
            return self._t2i_predict_v(
                image_embeds,
                image_indexes,
                attn_mask,
                cache,
                t,
                z,
                image_token_num=num_img_tokens,
                timestep_embeddings=timestep_embeddings,
                image_size=image_size,
            )

        out_cond = predict_v(indexes_image_condition, attention_mask_condition, past_key_values_condition)
        if not use_cfg:
            v_pred = out_cond
        elif cfg_scale == 1 and img_cfg_scale == 1:
            v_pred = out_cond
        elif img_cfg_scale == 1:
            out_img_cond = predict_v(
                indexes_image_img_condition, attention_mask_img_condition, past_key_values_img_condition
            )
            v_pred = out_img_cond + cfg_scale * (out_cond - out_img_cond)
        elif cfg_scale == img_cfg_scale:
            out_uncond = predict_v(
                indexes_image_uncondition, attention_mask_uncondition, past_key_values_uncondition
            )
            v_pred = out_uncond + cfg_scale * (out_cond - out_uncond)
        else:
            out_img_cond = predict_v(
                indexes_image_img_condition, attention_mask_img_condition, past_key_values_img_condition
            )
            out_uncond = predict_v(
                indexes_image_uncondition, attention_mask_uncondition, past_key_values_uncondition
            )
            v_pred = (
                out_uncond
                + cfg_scale * (out_cond - out_img_cond)
                + img_cfg_scale * (out_img_cond - out_uncond)
            )

        # Optionally shrink the guided velocity so that its norm does not
        # exceed the norm of the conditional prediction.
        if (cfg_scale > 1 or img_cfg_scale > 1) and use_cfg:
            if cfg_norm == 'global':
                norm_v_condition = torch.norm(out_cond, dim=(1, 2), keepdim=True)
                norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True)
                scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
                v_pred = v_pred * scale
            elif cfg_norm == 'channel':
                norm_v_condition = torch.norm(out_cond, dim=-1, keepdim=True)
                norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True)
                scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
                v_pred = v_pred * scale

        z = z + (t_next - t) * v_pred
        image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0])

    for cache in active_caches:
        clear_flash_kv_cache(cache)

    self.last_think_content = think_text
    if think_mode:
        return image_prediction, think_text
    return image_prediction
@torch.no_grad()
def t2i_generate(
    self,
    tokenizer,
    prompt,
    cfg_scale=1,
    timestep_shift=1,
    enable_timestep_shift=True,
    cfg_norm='none',
    image_size=(256, 256),
    num_steps=30,
    IMG_START_TOKEN='<img>',
    IMG_END_TOKEN='</img>',
    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
    method='euler',
    cfg_interval=(0, 1),
    batch_size=1,
    t_eps=0.02,
    think_mode=False,
    seed=0,
):
    """Generate ``batch_size`` images from a text ``prompt``.

    A conditional prefix cache is always built; an unconditional one (empty
    prompt) is added when ``cfg_scale > 1``. The image is produced by
    ``num_steps`` updates of the form ``z += (t_next - t) * v_pred``.

    Args:
        tokenizer: Tokenizer used to build the prompts.
        prompt: Text description of the image.
        cfg_scale: Guidance weight; guidance is only used when it exceeds 1.
        timestep_shift: Shift forwarded to ``self._apply_time_schedule``.
        enable_timestep_shift: Whether to apply that schedule at all.
        cfg_norm: One of ``'cfg_zero_star'``, ``'global'``, ``'none'``,
            ``'channel'``.
        image_size: ``(width, height)`` of the generated image.
        num_steps: Number of integration steps.
        IMG_START_TOKEN: Token that opens an image span.
        IMG_END_TOKEN: Not used by this method.
        IMG_CONTEXT_TOKEN: Not used by this method.
        method: Not used by this method.
        cfg_interval: Inclusive ``(low, high)`` timestep window for guidance.
        batch_size: Number of images generated from the same prefix.
        t_eps: Stored on ``self.config.t_eps``.
        think_mode: When True the model first generates reasoning text via
            ``self._generate_think``.
        seed: Seed for the noise generator.

    Returns:
        The generated image tensor, or ``(image, think_text)`` when
        ``think_mode`` is True. The reasoning text is also stored on
        ``self.last_think_content``.

    Raises:
        AssertionError: If ``self.concat_time_token_num`` is non-zero or
            ``cfg_norm`` is not supported.
    """
    assert self.concat_time_token_num == 0
    assert cfg_norm in ['cfg_zero_star', 'global', 'none', 'channel']

    merge_size = int(1 / self.downsample_ratio)
    self.config.t_eps = t_eps
    think_text = ""
    use_guidance = cfg_scale > 1

    # --- Prompts -----------------------------------------------------------
    if think_mode:
        # The reasoning text and the image-start token are generated later.
        prompt_suffix = '<think>\n'
    else:
        prompt_suffix = '<think>\n\n</think>\n\n' + IMG_START_TOKEN
    cond_query = self._build_t2i_query(
        f"{prompt}", system_message=SYSTEM_MESSAGE_FOR_GEN, append_text=prompt_suffix
    )
    uncond_query = None
    if use_guidance:
        uncond_query = self._build_t2i_query("", append_text=IMG_START_TOKEN)

    cond_ids, cond_indexes, cond_prefix_mask = self._build_t2i_text_inputs(tokenizer, cond_query)
    if uncond_query is None:
        uncond_ids = uncond_indexes = uncond_prefix_mask = None
    else:
        uncond_ids, uncond_indexes, uncond_prefix_mask = self._build_t2i_text_inputs(tokenizer, uncond_query)

    # --- Image token layout --------------------------------------------------
    tokens_per_side = self.patch_size * merge_size
    token_h = image_size[1] // tokens_per_side
    token_w = image_size[0] // tokens_per_side
    n_tokens = token_h * token_w

    cond_image_indexes = self._build_t2i_image_indexes(
        token_h, token_w, cond_indexes.shape[1], device=cond_ids.device
    )
    uncond_image_indexes = None
    if uncond_indexes is not None:
        uncond_image_indexes = self._build_t2i_image_indexes(
            token_h, token_w, uncond_indexes.shape[1], device=uncond_ids.device
        )

    # --- Prefix forward passes ---------------------------------------------
    if not think_mode:
        cond_cache, cond_hidden = self._t2i_prefix_forward(cond_ids, cond_indexes, cond_prefix_mask)
    else:
        cond_outputs = self.language_model(
            input_ids=cond_ids,
            indexes=cond_indexes,
            attention_mask=cond_prefix_mask,
            use_cache=True,
            output_hidden_states=True,
        )
        cond_cache = cond_outputs.past_key_values
        cond_hidden = cond_outputs.hidden_states[-1]
        last_t_index = cond_indexes[0].max().item()
        cond_cache, last_t_index, think_text = self._generate_think(
            tokenizer,
            cond_outputs,
            cond_cache,
            last_t_index,
            IMG_START_TOKEN,
        )
        # The prefix grew while thinking, so the image indexes are rebuilt.
        cond_image_indexes = self._build_t2i_image_indexes(
            token_h, token_w, last_t_index + 1, device=cond_ids.device
        )

    uncond_cache = None
    if uncond_ids is not None:
        uncond_cache, _ = self._t2i_prefix_forward(uncond_ids, uncond_indexes, uncond_prefix_mask)

    device = cond_hidden.device
    dtype = cond_hidden.dtype

    # Drop the prefix inputs before the sampling loop allocates its buffers.
    del cond_ids, cond_indexes, cond_prefix_mask
    del uncond_ids, uncond_indexes, uncond_prefix_mask
    del cond_hidden

    def widen_to_batch(tensor):
        # Broadcast dim 0 to ``batch_size`` while keeping every other dim.
        return tensor.expand(batch_size, *tensor.shape[1:])

    for layer_idx in range(len(cond_cache.layers)):
        cond_layer = cond_cache.layers[layer_idx]
        cond_layer.keys = widen_to_batch(cond_layer.keys)
        cond_layer.values = widen_to_batch(cond_layer.values)
        if uncond_cache is not None:
            uncond_layer = uncond_cache.layers[layer_idx]
            uncond_layer.keys = widen_to_batch(uncond_layer.keys)
            uncond_layer.values = widen_to_batch(uncond_layer.values)

    # The flash caches are prepared once and reused by every step.
    prepare_flash_kv_cache(cond_cache, current_len=n_tokens, batch_size=batch_size)
    if uncond_cache is not None:
        prepare_flash_kv_cache(uncond_cache, current_len=n_tokens, batch_size=batch_size)

    # --- Initial noise -------------------------------------------------------
    grid_h = image_size[1] // self.patch_size
    grid_w = image_size[0] // self.patch_size
    patch_grid = torch.tensor([[grid_h, grid_w]] * batch_size, device=device)

    noise_scale = self.noise_scale
    if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'):
        base = float(self.noise_scale_base_image_seq_len)
        scale = math.sqrt((grid_h * grid_w) / (merge_size**2) / base)
        noise_scale = scale * float(self.noise_scale)
        if self.noise_scale_mode == 'dynamic_sqrt':
            noise_scale = math.sqrt(noise_scale)
        # NOTE(review): the cap is applied to the resolution-scaled value
        # only - confirm a static noise_scale stays uncapped.
        noise_scale = min(noise_scale, self.noise_scale_max_value)

    generator = torch.Generator(device).manual_seed(seed)
    image_prediction = noise_scale * torch.randn(
        (batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype, generator=generator
    )

    cond_mask = {"full_attention": None}
    uncond_mask = {"full_attention": None}

    timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
    if enable_timestep_shift:
        timesteps = self._apply_time_schedule(timesteps, n_tokens, timestep_shift)

    # --- Sampling loop -------------------------------------------------------
    for step_i in range(num_steps):
        t = timesteps[step_i]
        t_next = timesteps[step_i + 1]

        z = self.patchify(image_prediction, self.patch_size * merge_size)
        image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
        image_embeds = self.extract_feature(
            image_input.view(batch_size * grid_h * grid_w, -1), gen_model=True, grid_hw=patch_grid
        ).view(batch_size, n_tokens, -1)

        t_expanded = t.expand(batch_size * n_tokens)
        timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, n_tokens, -1)
        if self.add_noise_scale_embedding:
            noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value)
            noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(
                batch_size, n_tokens, -1
            )
            timestep_embeddings += noise_embeddings
        image_embeds = image_embeds + timestep_embeddings

        def predict_velocity(image_indexes, mask, cache):
            # One velocity prediction for the current step on the given branch.
            return self._t2i_predict_v(
                image_embeds,
                image_indexes,
                mask,
                cache,
                t,
                z,
                image_token_num=n_tokens,
                timestep_embeddings=timestep_embeddings,
                image_size=image_size,
            )

        cond_v = predict_velocity(cond_image_indexes, cond_mask, cond_cache)

        if cfg_interval[0] <= t <= cfg_interval[1] and cfg_scale > 1:
            uncond_v = predict_velocity(uncond_image_indexes, uncond_mask, uncond_cache)
            if cfg_norm == 'cfg_zero_star':
                positive_flat = cond_v.view(batch_size, -1)
                negative_flat = uncond_v.view(batch_size, -1)
                alpha = optimized_scale(positive_flat, negative_flat)
                alpha = alpha.view(batch_size, *([1] * (len(cond_v.shape) - 1)))
                alpha = alpha.to(positive_flat.dtype)
                if step_i > 0:
                    v_pred = uncond_v * alpha + cfg_scale * (cond_v - uncond_v * alpha)
                else:
                    # The very first step uses a zero velocity in this mode.
                    v_pred = cond_v * 0.
            else:
                v_pred = uncond_v + cfg_scale * (cond_v - uncond_v)

            if cfg_norm in ('global', 'channel'):
                # Shrink the guided velocity so that its norm does not exceed
                # the norm of the conditional prediction.
                norm_dims = (1, 2) if cfg_norm == 'global' else -1
                norm_v_condition = torch.norm(cond_v, dim=norm_dims, keepdim=True)
                norm_v_cfg = torch.norm(v_pred, dim=norm_dims, keepdim=True)
                scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
                v_pred = v_pred * scale
        else:
            v_pred = cond_v

        z = z + (t_next - t) * v_pred
        image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0])

    clear_flash_kv_cache(cond_cache)
    if uncond_cache is not None:
        clear_flash_kv_cache(uncond_cache)

    self.last_think_content = think_text
    if think_mode:
        return image_prediction, think_text
    return image_prediction
def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, grid_hw=None,
         IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False):
    """Run one chat turn and return the decoded answer.

    Args:
        tokenizer: Tokenizer used to resolve the special image tokens, encode
            the prompt and decode the generated ids.
        pixel_values: Image input forwarded to :meth:`generate`, or ``None``
            for a text-only turn.
        question: User question. ``'<image>\n'`` is prepended when this is the
            first turn, an image is given and no ``<image>`` placeholder exists.
        generation_config: Mapping of generation kwargs. NOTE: it is updated in
            place with ``eos_token_id`` (existing behaviour, kept on purpose).
        history: Optional list of ``(question, answer)`` pairs; the new pair is
            appended to it in place.
        return_history: When true, return ``(response, history)``.
        grid_hw: Per-image grid sizes, one row per image; each ``<image>``
            placeholder is expanded using the matching row. May be ``None``
            for a text-only turn.
        IMG_START_TOKEN / IMG_END_TOKEN / IMG_CONTEXT_TOKEN: Special tokens
            wrapped around / repeated inside each expanded image placeholder.
        verbose: Print the image size and the (compacted) prompt + response.

    Returns:
        ``response`` or ``(response, history)`` depending on ``return_history``.

    Raises:
        ValueError: If ``pixel_values`` is given without ``grid_hw``.
    """
    if pixel_values is not None and grid_hw is None:
        # Previously this failed later with an opaque error on `grid_hw.shape`.
        raise ValueError('grid_hw is required when pixel_values is provided')

    if history is None and pixel_values is not None and '<image>' not in question:
        question = '<image>\n' + question

    # generate() / get_thw_indexes() read these ids from the instance.
    img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    self.img_context_token_id = img_context_token_id
    self.img_start_token_id = tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)

    template = get_conv_template(self.template)
    template.system_message = self.system_message
    eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())

    history = [] if history is None else history
    for (old_question, old_answer) in history:
        template.append_message(template.roles[0], old_question)
        template.append_message(template.roles[1], old_answer)
    template.append_message(template.roles[0], question)
    template.append_message(template.roles[1], None)
    query = template.get_prompt()

    if verbose and pixel_values is not None:
        print(f'dynamic image size: {grid_hw[0] * self.patch_size}')

    # Text-only turns carry no grid; skipping the expansion avoids the former
    # crash on `None.shape` and leaves the prompt untouched.
    if grid_hw is not None:
        for i in range(grid_hw.shape[0]):
            num_patch_token = int(grid_hw[i, 0] * grid_hw[i, 1] * self.downsample_ratio**2)
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_patch_token + IMG_END_TOKEN
            # Replace placeholders one at a time, in image order.
            query = query.replace('<image>', image_tokens, 1)

    model_inputs = tokenizer(query, return_tensors='pt')
    input_ids = model_inputs['input_ids'].to(self.device)
    attention_mask = model_inputs['attention_mask'].to(self.device)
    generation_config['eos_token_id'] = eos_token_id
    generation_output = self.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        grid_hw=grid_hw,
        attention_mask=attention_mask,
        **generation_config
    )
    response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
    # Cut everything after the template separator if it leaked into the text.
    response = response.split(template.sep.strip())[0].strip()
    history.append((question, response))
    if return_history:
        return response, history
    else:
        # Collapse the expanded image tokens back to '<image>' for readable logs.
        query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
        query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
        if verbose:
            print(query_to_print, response)
        return response
@torch.no_grad()
def generate(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    input_ids: Optional[torch.FloatTensor] = None,
    grid_hw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    visual_features: Optional[torch.FloatTensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    output_hidden_states: Optional[bool] = None,
    **generate_kwargs,
) -> torch.LongTensor:
    """Generate token ids for a single prompt, optionally conditioned on images.

    Token embeddings are looked up from the language model; when
    ``pixel_values`` is given, every position holding the image-context token
    is overwritten with visual features (``visual_features`` if supplied,
    otherwise ``self.extract_feature(pixel_values, grid_hw=grid_hw)``).
    The result is handed to ``self.language_model.generate`` together with the
    (t, h, w) position indexes from :meth:`get_thw_indexes`.

    Only batch size 1 is supported, and ``self.img_context_token_id`` must have
    been set beforehand (both enforced by assertions).
    """
    assert input_ids.shape[0] == 1
    assert self.img_context_token_id is not None
    indexes = self.get_thw_indexes(input_ids[0], grid_hw)

    has_image = pixel_values is not None
    if has_image:
        # Pre-computed features take precedence over running the vision tower.
        vit_embeds = (
            visual_features
            if visual_features is not None
            else self.extract_feature(pixel_values, grid_hw=grid_hw)
        )

    input_embeds = self.language_model.get_input_embeddings()(input_ids)

    if has_image:
        batch, seq_len, hidden = input_embeds.shape
        flat_embeds = input_embeds.reshape(batch * seq_len, hidden)
        flat_ids = input_ids.reshape(batch * seq_len)
        image_slots = (flat_ids == self.img_context_token_id)
        assert image_slots.sum() != 0
        # Scatter the visual features into the image-context positions.
        flat_embeds[image_slots] = vit_embeds.reshape(-1, hidden).to(flat_embeds.device)
        input_embeds = flat_embeds.reshape(batch, seq_len, hidden)

    return self.language_model.generate(
        inputs_embeds=input_embeds,
        indexes=indexes,
        attention_mask=attention_mask,
        generation_config=generation_config,
        output_hidden_states=output_hidden_states,
        use_cache=True,
        **generate_kwargs,
    )
@property
def lm_head(self):
    """Output-embedding (LM head) module of the wrapped language model."""
    language_model = self.language_model
    return language_model.get_output_embeddings()
def get_output_embeddings(self):
    """Delegate to the wrapped language model's ``get_output_embeddings``."""
    language_model = self.language_model
    return language_model.get_output_embeddings()
def get_input_embeddings(self):
    """Delegate to the wrapped language model's ``get_input_embeddings``."""
    language_model = self.language_model
    return language_model.get_input_embeddings()
def set_input_embeddings(self, value):
    """Forward ``value`` to the language model's ``set_input_embeddings`` and return its result."""
    language_model = self.language_model
    return language_model.set_input_embeddings(value)
def set_output_embeddings(self, value):
    """Forward ``value`` to the language model's ``set_output_embeddings`` and return its result."""
    language_model = self.language_model
    return language_model.set_output_embeddings(value)
def get_thw_indexes(self, input_ids, grid_hw=None):
    """Build stacked (t, h, w) position indexes for a 1-D token sequence.

    The ``t`` index advances by one for every non-image-context token and once
    more right after each image-start token, so all context tokens of one image
    share a single ``t`` value. ``h`` / ``w`` are zero everywhere except at
    image-context tokens, where they are filled from
    ``build_abs_positions_from_grid_hw`` on the down-sampled grid (only when
    ``grid_hw`` is given and the sequence contains such tokens).

    Returns:
        Tensor of shape ``(3, len(input_ids))`` ordered as ``[t, h, w]``.
    """
    is_img_context = (input_ids == self.img_context_token_id)

    # 1 at the position immediately following an image-start token.
    leading_zero = torch.zeros(1, dtype=torch.long).to(input_ids.device)
    is_img_start = (input_ids == self.img_start_token_id).long()
    after_img_start = torch.cat([leading_zero, is_img_start], dim=0)[:-1]

    is_plain_token = (input_ids != self.img_context_token_id).long()
    t_indexes = (after_img_start + is_plain_token).cumsum(0) - 1

    h_indexes = torch.zeros_like(t_indexes)
    w_indexes = torch.zeros_like(t_indexes)

    if grid_hw is not None and is_img_context.long().sum() > 0:
        merged_grid = grid_hw // int(1 / self.downsample_ratio)
        abs_pos_w, abs_pos_h = build_abs_positions_from_grid_hw(merged_grid, device=t_indexes.device)
        h_indexes[is_img_context] = abs_pos_h.to(t_indexes.device, t_indexes.dtype)
        w_indexes[is_img_context] = abs_pos_w.to(t_indexes.device, t_indexes.dtype)

    return torch.stack([t_indexes, h_indexes, w_indexes], dim=0)
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from .configuration_neo_vit import NEOVisionConfig
def precompute_rope_freqs_sincos(
    dim: int, max_position: int, base: float = 10000.0, device=None
):
    """Precompute the 1D RoPE cosine and sine tables.

    Returns:
        ``(cos, sin)``, each of shape ``(max_position, ceil(dim / 2))``, where
        entry ``[p, i]`` is the cos / sin of ``p * base ** (-2 * i / dim)``.
    """
    exponents = torch.arange(0, dim, 2, device=device).float() / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(max_position, device=device).type_as(inv_freq)
    # Outer product written as a broadcasted multiply: angle[p, i] = p * inv_freq[i].
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
    return angles.cos(), angles.sin()
def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
    """Compute per-patch (x, y) coordinates for a batch of image grids.

    Patches of every image are enumerated in row-major order and the images
    are concatenated in batch order.

    Args:
        grid_hw: ``(B, 2)`` integer tensor holding ``(H, W)`` for each image.
        device: Device on which the coordinates are computed and returned.
            Defaults to ``grid_hw.device``. (The argument used to be silently
            overwritten; it is now honoured.)

    Returns:
        ``(abs_x, abs_y)``: two 1-D tensors of length ``sum(H * W)`` holding
        the column and row index of every patch within its own image.
    """
    if device is None:
        device = grid_hw.device
    grid_hw = grid_hw.to(device)

    num_images = grid_hw.shape[0]
    heights = grid_hw[:, 0]
    widths = grid_hw[:, 1]
    patches_per_image = heights * widths

    # Image index of every patch, shape (N_total,).
    image_of_patch = torch.repeat_interleave(
        torch.arange(num_images, device=device), patches_per_image
    )

    # Exclusive prefix sum: global index of the first patch of each image.
    first_patch = torch.cumsum(patches_per_image, dim=0) - patches_per_image

    # Row-major patch index inside its own image.
    patch_id_within_image = (
        torch.arange(image_of_patch.shape[0], device=device) - first_patch[image_of_patch]
    )

    width_of_patch = widths[image_of_patch]
    abs_x = patch_id_within_image % width_of_patch
    abs_y = patch_id_within_image // width_of_patch
    return abs_x, abs_y
def apply_rotary_emb_1d(
    x: torch.Tensor,
    cos_cached: torch.Tensor,
    sin_cached: torch.Tensor,
    positions: torch.Tensor,
):
    """Apply 1D RoPE to (a slice of) the input tensor.

    Adjacent channel pairs ``(x[2i], x[2i+1])`` are rotated by the angle whose
    cos / sin are looked up in the caches at ``positions``.

    Args:
        x: ``(..., seq_len, dim_part)``.
        cos_cached / sin_cached: ``(max_pos, dim_part / 2)`` lookup tables.
        positions: ``(..., seq_len)`` integer positions.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    cos = cos_cached[positions]  # (*positions.shape, dim_part / 2)
    sin = sin_cached[positions]

    even = x[..., 0::2]
    odd = x[..., 1::2]

    # Stack the rotated pair on a new trailing axis, then flatten it so the
    # results interleave back into channels 0, 1, 2, 3, ...
    rotated_pairs = torch.stack(
        (even * cos - odd * sin, even * sin + odd * cos), dim=-1
    )
    return rotated_pairs.flatten(-2).to(x.dtype)
def apply_2d_rotary_pos_emb(
    x: torch.Tensor,
    cos_cached_x: torch.Tensor,
    sin_cached_x: torch.Tensor,
    cos_cached_y: torch.Tensor,
    sin_cached_y: torch.Tensor,
    abs_positions_x: torch.Tensor,
    abs_positions_y: torch.Tensor
):
    """Apply 2D RoPE to the input tensor ``x``.

    The last dimension is split in two: the first ``dim // 2`` channels are
    rotated according to the x coordinates, the remaining channels according
    to the y coordinates. The two halves are concatenated back in that order.
    """
    split_at = x.shape[-1] // 2

    # First half encodes the x position, second half the y position; the order
    # must match the concatenation below.
    rotated_by_x = apply_rotary_emb_1d(
        x[..., :split_at], cos_cached_x, sin_cached_x, abs_positions_x
    )
    rotated_by_y = apply_rotary_emb_1d(
        x[..., split_at:], cos_cached_y, sin_cached_y, abs_positions_y
    )
    return torch.cat((rotated_by_x, rotated_by_y), dim=-1)
class NEOVisionEmbeddings(nn.Module):
    """Patch embedding for native-resolution images.

    Pipeline: flattened pixel patches -> ``patch_embedding`` conv + GELU ->
    2D RoPE on the patch grid -> per-image ``dense_embedding`` conv that merges
    ``downsample_factor x downsample_factor`` neighbouring patches and projects
    to the LLM hidden size.
    """

    def __init__(self, config: NEOVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        # NOTE(review): these config fields are indexed with [0], so they are
        # presumably stored as 1-element sequences -- confirm in the config class.
        self.llm_embed_dim = config.llm_hidden_size[0]
        self.downsample_factor = int(1 / config.downsample_ratio[0])
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )
        self.dense_embedding = nn.Conv2d(
            in_channels=self.embed_dim, out_channels=self.llm_embed_dim, kernel_size=self.downsample_factor, stride=self.downsample_factor
        )
        self.gelu = nn.GELU()

        # Half of the channels rotate with x, the other half with y.
        self.rope_dim_part = self.embed_dim // 2
        cos_x, sin_x = precompute_rope_freqs_sincos(
            self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None
        )
        cos_y, sin_y = precompute_rope_freqs_sincos(
            self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None
        )
        # Non-persistent: the tables are recomputed here, not stored in checkpoints.
        self.register_buffer("cos_cached_x", cos_x, persistent=False)
        self.register_buffer("sin_cached_x", sin_x, persistent=False)
        self.register_buffer("cos_cached_y", cos_y, persistent=False)
        self.register_buffer("sin_cached_y", sin_y, persistent=False)

    def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
        """Apply 2D rotary position embedding to the flat patch embeddings.

        The rotation is computed in float32 and the result is cast back to the
        dtype of the ``patch_embedding`` weights.
        """
        abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
        embeddings = apply_2d_rotary_pos_emb(
            patch_embeds.to(torch.float32),  # RoPE calculations are often more stable in float32
            self.cos_cached_x, self.sin_cached_x,
            self.cos_cached_y, self.sin_cached_y,
            abs_pos_x,
            abs_pos_y
        ).to(self.patch_embedding.weight.dtype)
        return embeddings

    def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor:
        """Embed flattened patches of one or more images.

        Args:
            pixel_values: Flattened patches; reshaped here to
                ``(-1, num_channels, patch_size, patch_size)``.
            grid_hw: ``(B, 2)`` tensor with the patch-grid ``(H, W)`` of every
                image. Required despite the ``None`` default.

        Returns:
            ``(sum(H * W) / downsample_factor**2, llm_embed_dim)`` embeddings,
            images concatenated in batch order.
        """
        # Use the configured channel count (was hard-coded to 3) so the reshape
        # always agrees with `patch_embedding.in_channels`.
        pixel_values = pixel_values.view(
            -1,
            self.config.num_channels,
            self.patch_size,
            self.patch_size,
        )  # e.g. [28072, 768] -> [28072, 3, 16, 16]
        patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)

        # Keep the RoPE tables on the same device as the activations.
        self.cos_cached_x = self.cos_cached_x.to(patch_embeds.device)
        self.sin_cached_x = self.sin_cached_x.to(patch_embeds.device)
        self.cos_cached_y = self.cos_cached_y.to(patch_embeds.device)
        self.sin_cached_y = self.sin_cached_y.to(patch_embeds.device)
        patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw)  # e.g. [28072, 1024]
        assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[0]

        patches_list = []
        cur_position = 0
        # Convert the grid once to Python ints: one host sync instead of using
        # 0-dim tensors as slice bounds / view sizes on every iteration.
        for h, w in grid_hw.tolist():
            num_patches = h * w
            patches_per_img = patch_embeds[cur_position : cur_position + num_patches].view(h, w, -1).unsqueeze(0)
            # (1, H, W, C) -> (1, C, H, W) for the merging conv, then back.
            patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2))
            patches_per_img = patches_per_img.permute(0, 2, 3, 1)
            patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1]))
            cur_position += num_patches

        embeddings = torch.cat(patches_list, dim=0)  # (N_total // downsample_factor**2, C)
        assert cur_position == patch_embeds.shape[0]
        assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor**2)
        return embeddings
class NEOVisionModel(PreTrainedModel):
    """Vision tower that consists solely of :class:`NEOVisionEmbeddings`.

    ``forward`` either embeds ``pixel_values`` or passes pre-computed
    ``pixel_embeds`` through unchanged, and always wraps the result in a
    ``BaseModelOutputWithPooling`` (``output_hidden_states`` / ``return_dict``
    are resolved against the config but do not alter the output).
    """

    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    config_class = NEOVisionConfig
    # support transformers 4.51.+
    _tp_plan = ''

    def __init__(self, config: NEOVisionConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = NEOVisionEmbeddings(config)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
        grid_hw: Optional[torch.Tensor] = None
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Embed image patches (or pass through ``pixel_embeds``).

        Raises:
            ValueError: If neither ``pixel_values`` nor ``pixel_embeds`` is given.
        """
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is None:
            assert pixel_values.dim() == 2, f"pixel_values must be 2D for native resolution, got: {pixel_values.dim()}"
            hidden_states = self.embeddings(pixel_values, grid_hw=grid_hw)
        else:
            # Pre-computed embeddings win over raw pixels.
            hidden_states = pixel_embeds

        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=None,
            hidden_states=None,
            attentions=None,
        )
from typing import Callable, Optional, Union
import torch
import torch._dynamo
from torch import nn
import copy
import math
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import (
GenericForQuestionAnswering,
GenericForSequenceClassification,
GenericForTokenClassification,
GradientCheckpointingLayer,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers import Qwen3Config
try:
from flash_attn import flash_attn_func # type: ignore
_HAS_FLASH_ATTN = True
except ImportError: # pragma: no cover - exercised only in CPU-only / no-flash envs
flash_attn_func = None # type: ignore
_HAS_FLASH_ATTN = False
# Attention backend dispatch.
#
# Set via :func:`set_attn_backend`. Three modes are accepted:
# * ``"auto"`` - use flash-attn if available, otherwise SDPA (default).
# * ``"flash"`` - force flash-attn; raise if ``flash_attn`` is not installed.
# * ``"sdpa"`` - force torch SDPA (useful for reproducibility tests and
# debugging, even when flash-attn is available).

# Names accepted by set_attn_backend(); any other value is rejected with ValueError.
_VALID_ATTN_BACKENDS = ("auto", "flash", "sdpa")
# Currently selected backend. Module-level state, rewritten by set_attn_backend().
_ATTN_BACKEND: str = "auto"
def set_attn_backend(backend: str) -> str:
    """Select the attention kernel used by the Qwen3 layers at runtime.

    The name is matched case-insensitively and stored in the module-level
    ``_ATTN_BACKEND``.

    Returns:
        The (lower-cased) backend name that is now active.

    Raises:
        ValueError: If the name is not one of ``_VALID_ATTN_BACKENDS``.
        RuntimeError: If ``flash`` is requested but ``flash_attn`` could not
            be imported.
    """
    global _ATTN_BACKEND

    backend = backend.lower()
    if backend not in _VALID_ATTN_BACKENDS:
        raise ValueError(
            f"Unknown attention backend {backend!r}. "
            f"Expected one of {_VALID_ATTN_BACKENDS}."
        )

    if backend == "flash":
        # Fail at configuration time instead of on the first forward pass.
        if not _HAS_FLASH_ATTN:
            raise RuntimeError(
                "Requested attn_backend='flash' but `flash_attn` is not installed. "
                "Install it (e.g. `uv pip install <flash_attn-*.whl>`) or use "
                "'auto' / 'sdpa'."
            )

    _ATTN_BACKEND = backend
    return backend
def get_attn_backend() -> str:
    """Return the configured attention backend name.

    This is the raw setting, so it may be ``'auto'``.
    """
    configured = _ATTN_BACKEND
    return configured
def effective_attn_backend() -> str:
    """Return the backend that will actually run, resolving ``'auto'``.

    ``'auto'`` becomes ``'flash'`` when ``flash_attn`` was importable and
    ``'sdpa'`` otherwise; explicit settings are returned unchanged.
    """
    if _ATTN_BACKEND == "auto":
        if _HAS_FLASH_ATTN:
            return "flash"
        return "sdpa"
    return _ATTN_BACKEND
def _sdpa_attn_func(q, k, v, dropout_p: float = 0.0, softmax_scale=None, causal: bool = False):
    """Drop-in SDPA fallback for ``flash_attn_func``.

    ``flash_attn_func`` expects q/k/v in layout ``[B, S, H, D]`` and returns
    ``[B, S_q, H_q, D]``. ``torch.nn.functional.scaled_dot_product_attention``
    expects ``[B, H, S, D]``; we transpose in and out.

    ``flash_attn_func`` natively handles Grouped-Query Attention (GQA) where
    ``H_q > H_kv``. Plain ``scaled_dot_product_attention`` only supports that
    via the ``enable_gqa=True`` kwarg (torch >= 2.5). For broader compatibility
    we just materialize the repeat manually when needed.

    Args:
        q, k, v: Query / key / value tensors in ``[B, S, H, D]`` layout.
        dropout_p: Attention dropout probability.
        softmax_scale: Scale applied to ``q @ k^T``; ``None`` means SDPA's
            default ``1 / sqrt(D)``.
        causal: Forwarded as ``is_causal``.
            NOTE(review): flash-attn documents a bottom-right aligned causal
            mask when ``S_q != S_kv`` whereas SDPA's ``is_causal`` is top-left
            aligned -- confirm callers only use ``causal=True`` with equal
            lengths.

    Returns:
        Contiguous tensor of shape ``[B, S_q, H_q, D]``.

    Raises:
        ValueError: If ``H_q`` is not a multiple of ``H_kv``.
    """
    q_bhsd = q.transpose(1, 2)
    k_bhsd = k.transpose(1, 2)
    v_bhsd = v.transpose(1, 2)

    h_q = q_bhsd.shape[1]
    h_kv = k_bhsd.shape[1]
    if h_q != h_kv:
        if h_q % h_kv != 0:
            raise ValueError(
                f"Cannot broadcast key/value heads ({h_kv}) to query heads ({h_q}): not divisible."
            )
        n_rep = h_q // h_kv
        k_bhsd = k_bhsd.repeat_interleave(n_rep, dim=1)
        v_bhsd = v_bhsd.repeat_interleave(n_rep, dim=1)

    sdpa = torch.nn.functional.scaled_dot_product_attention
    # SDPA does not support an explicit `scale` argument on older torch
    # versions; fall back to the manual path in that case.
    try:
        out = sdpa(
            q_bhsd,
            k_bhsd,
            v_bhsd,
            dropout_p=dropout_p,
            is_causal=causal,
            scale=softmax_scale,
        )
    except TypeError:
        if softmax_scale is not None:
            # Without `scale`, SDPA still multiplies the logits by 1/sqrt(D).
            # Pre-scale q by softmax_scale * sqrt(D) so the net factor is
            # exactly softmax_scale (scaling by softmax_scale alone, as before,
            # double-scaled the logits).
            q_bhsd = q_bhsd * (softmax_scale * math.sqrt(q_bhsd.shape[-1]))
        out = sdpa(
            q_bhsd,
            k_bhsd,
            v_bhsd,
            dropout_p=dropout_p,
            is_causal=causal,
        )
    return out.transpose(1, 2).contiguous()
def _flash_or_sdpa(q, k, v, dropout_p: float = 0.0, softmax_scale=None, causal: bool = False):
    """Dispatch attention to flash-attn or the SDPA fallback.

    flash-attn is used only when it is the effective backend *and* ``q`` lives
    on a CUDA device; every other combination goes through ``_sdpa_attn_func``.
    """
    # flash-attn ships CUDA kernels only. On XPU / CPU we transparently fall
    # back to SDPA even if the user asked for ``flash`` — the alternative
    # (crashing on first forward) is worse, and ``set_attn_backend('flash')``
    # already guarded against the "package missing" case.
    use_flash = effective_attn_backend() == "flash" and q.device.type == "cuda"
    attn_fn = flash_attn_func if use_flash else _sdpa_attn_func
    return attn_fn(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal)
def create_block_causal_mask(index: torch.Tensor):
    """Build an additive block-wise causal attention mask.

    Position ``i`` may attend to position ``j`` when ``j <= i`` (causal) or
    when both carry the same value in ``index`` (same block).

    Args:
        index: ``(L,)`` block id per position.

    Returns:
        ``(1, 1, L, L)`` float mask: ``0.0`` where attention is allowed and
        ``-inf`` elsewhere.
    """
    length = index.size(0)
    positions = torch.arange(length, device=index.device)

    # Broadcasting a column against a row yields the full (L, L) comparison.
    same_block = index.unsqueeze(0) == index.unsqueeze(1)
    not_future = positions.unsqueeze(0) <= positions.unsqueeze(1)
    allowed = same_block | not_future

    return torch.where(allowed[None, None, :, :], torch.tensor(0.0), torch.tensor(float('-inf')))
def visualize_mask(mask: torch.Tensor, i: int = 0, j: int = 12):
    """Print the ``[i:j, i:j]`` window of an additive attention mask.

    Args:
        mask: ``(1, 1, L, L)`` additive mask where ``0`` means "may attend".
        i: First row / column of the window (inclusive).
        j: Last row / column of the window (exclusive).

    Each printed row contains ``1`` for positions that may be attended to and
    ``0`` for masked positions, separated by spaces.
    """
    # Binarise first, then window the 2-D result. The previous code discarded
    # the binarised tensor and sliced the raw 4-D mask instead, which windowed
    # the wrong axes and cast -inf entries to int.
    visible = (mask[0, 0, :, :] == 0).int()
    window = visible[i:j, i:j].cpu().tolist()
    for row in window:
        print(" ".join(map(str, row)))
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Qwen3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise over the last dimension.

        Statistics are computed in float32; the normalised values are cast
        back to the input dtype before the learned scale is applied.
        """
        source_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(source_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Qwen3MLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        hidden, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.down_proj(gated)
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last dimension at ``size // 2`` into ``(a, b)`` and returns
    ``(-b, a)`` concatenated along that dimension.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they
            broadcast against `q` and `k`. With `cos` / `sin` of shape
            [batch_size, seq_len, head_dim], use 1 when `q` / `k` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``: the
    input goes from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). With ``n_rep == 1``
    the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis next to the head axis, broadcast it, then fold it in.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled dot-product attention.

    Key / value heads are repeated ``module.num_key_value_groups`` times to
    match the query heads. ``attention_mask`` is additive and is cropped on its
    last axis to the key length. Softmax runs in float32 and is cast back to
    the query dtype; dropout is active only while ``module.training``.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has its head and
        sequence axes swapped (``transpose(1, 2)``) and is contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        key_len = expanded_keys.shape[-2]
        scores = scores + attention_mask[:, :, :, :key_len]

    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
def _compute_default_rope_parameters(config, device=None, **_kwargs):
    """Default RoPE frequencies, inlined to avoid breakage across transformers versions.

    transformers <=4.x exposes this as ``ROPE_INIT_FUNCTIONS["default"]``, but
    5.x dropped the ``"default"`` key from that table. Having a local copy keeps
    ``Qwen3RotaryEmbedding`` working on both.

    Returns:
        ``(inv_freq, attention_factor)`` where ``inv_freq[i]`` equals
        ``rope_theta ** (-2 * i / dim)`` with
        ``dim = int(head_dim * partial_rotary_factor)``, and
        ``attention_factor`` is always ``1.0``.
    """
    base = config.rope_theta
    rotary_fraction = getattr(config, "partial_rotary_factor", 1.0)
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    rotary_dim = int(head_dim * rotary_fraction)

    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.int64).float().to(device) / rotary_dim
    inv_freq = 1.0 / (base ** exponents)
    return inv_freq, 1.0
class Qwen3RotaryEmbedding(nn.Module):
    """Rotary embedding that produces cos / sin tables for given position ids.

    Unlike a plain RoPE table, the inverse frequencies are taken as every
    second entry of the table that the base init function yields for *twice*
    the configured ``head_dim`` (see ``_rope_init_fn_keep_freq_range``).
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3Config, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # "default" (or a rope_scaling dict without a type) uses the local copy
        # of the default init; any other type is looked up in transformers.
        if self.rope_type == "default" or self.rope_type is None:
            base_rope_init_fn = _compute_default_rope_parameters
        else:
            base_rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        def _rope_init_fn_keep_freq_range(cfg: Qwen3Config, dev=None):
            # First call: only its attention_scaling survives; its inv_freq is
            # replaced below.
            inv_freq, attention_scaling = base_rope_init_fn(cfg, dev)
            # Work on a copy so the caller's config is not modified.
            cfg2 = copy.deepcopy(cfg)
            head_dim = getattr(cfg2, "head_dim", None)
            if head_dim is None:
                head_dim = cfg2.hidden_size // cfg2.num_attention_heads
                setattr(cfg2, "head_dim", head_dim)
            # Build the table for a head twice as wide, then keep every second
            # frequency, so the result has as many entries as the first call.
            cfg2.head_dim = int(head_dim) * 2
            inv_freq_full, _ = base_rope_init_fn(cfg2, dev)
            inv_freq = inv_freq_full[::2]
            return inv_freq, attention_scaling

        self.rope_init_fn = _rope_init_fn_keep_freq_range
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Snapshot kept alongside the buffer. NOTE(review): presumably consumed
        # by `dynamic_rope_update` for dynamic RoPE types -- confirm against the
        # installed transformers version.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids``, cast to ``x.dtype``.

        ``x`` is only used for its device and dtype. ``position_ids`` is
        indexed as 2-D (``position_ids[:, None, :]``); the outputs have shape
        ``(position_ids.shape[0], position_ids.shape[1], 2 * len(inv_freq))``.
        """
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # autocast is given "cpu" instead of "mps" as device type.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so they line up with the two halves used by rotate_half().
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Qwen3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Qwen3Config, layer_idx: int):
    """Create the attention projections, norms and rotary embeddings.

    Every projection / norm exists twice: the plain one and a ``*_mot_gen``
    twin (used by ``forward_und`` and ``forward_gen`` respectively). The head
    dimension is split in two halves: one rotated with the ``t`` index
    (``rotary_emb``) and one carrying ``h`` and ``w`` (``rotary_emb_hw``,
    a quarter of the head each), hence the ``head_dim // 2`` norms.

    Args:
        config: Model config; must also provide ``rope_theta_hw`` and
            ``max_position_embeddings_hw``.
        layer_idx: Index of this layer, used for the KV cache and to select
            the entry of ``config.layer_types``.
    """
    super().__init__()
    self.config = config
    self.layer_idx = layer_idx
    self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
    self.scaling = self.head_dim**-0.5
    self.attention_dropout = config.attention_dropout
    self.is_causal = True

    q_out = config.num_attention_heads * self.head_dim
    kv_out = config.num_key_value_heads * self.head_dim
    bias = config.attention_bias

    # Creation order is kept (plain, then *_mot_gen twin) so parameter
    # registration order and RNG-driven initialisation are unchanged.
    self.q_proj = nn.Linear(config.hidden_size, q_out, bias=bias)
    self.q_proj_mot_gen = nn.Linear(config.hidden_size, q_out, bias=bias)
    self.k_proj = nn.Linear(config.hidden_size, kv_out, bias=bias)
    self.k_proj_mot_gen = nn.Linear(config.hidden_size, kv_out, bias=bias)
    self.v_proj = nn.Linear(config.hidden_size, kv_out, bias=bias)
    self.v_proj_mot_gen = nn.Linear(config.hidden_size, kv_out, bias=bias)
    self.o_proj = nn.Linear(q_out, config.hidden_size, bias=bias)
    self.o_proj_mot_gen = nn.Linear(q_out, config.hidden_size, bias=bias)

    # Norms act on half a head (unlike olmo, only on the head dim), so no
    # reshape is needed after the per-head chunking.
    half_head = self.head_dim // 2
    self.q_norm = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)
    self.q_norm_mot_gen = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)
    self.q_norm_hw = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)
    self.q_norm_hw_mot_gen = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)
    self.k_norm = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)
    self.k_norm_mot_gen = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)
    self.k_norm_hw = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)
    self.k_norm_hw_mot_gen = Qwen3RMSNorm(half_head, eps=config.rms_norm_eps)

    self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    # Use the resolved self.head_dim (not config.head_dim) so a config that
    # lacks `head_dim` works here just like it does for the projections above.
    t_config = copy.deepcopy(config)
    t_config.head_dim = self.head_dim // 2
    self.rotary_emb = Qwen3RotaryEmbedding(config=t_config)

    hw_config = copy.deepcopy(config)
    hw_config.head_dim = self.head_dim // 4
    hw_config.rope_theta = config.rope_theta_hw
    hw_config.max_position_embeddings = config.max_position_embeddings_hw
    self.rotary_emb_hw = Qwen3RotaryEmbedding(config=hw_config)
def forward_und(
    self,
    hidden_states: torch.Tensor,
    indexes: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Attention forward pass using the plain (non ``*_mot_gen``) weights.

    Args:
        hidden_states: Layer input; its last dim is projected by q/k/v.
            Assumed to be (batch, seq_len, hidden) -- TODO confirm.
        indexes: Stacked position indexes; rows 0, 1, 2 are used as the
            t, h and w positions respectively.
        attention_mask: Additive mask forwarded to the attention function.
        past_key_values: Optional KV cache.
        cache_position: Unused in this method.
        **kwargs: ``update_cache`` (default True) selects whether the current
            keys/values are appended to the cache. All kwargs, including
            ``update_cache``, are also forwarded to the attention function.

    Returns:
        ``(attn_output, attn_weights)`` after the output projection.
    """
    # Only the eager attention implementation is supported here.
    assert self.config._attn_implementation == "eager"
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # Queries: split each head into a "t" half and an "hw" half, normalise them
    # separately, then split the hw half again into h and w quarters.
    query_states = self.q_proj(hidden_states).view(hidden_shape)
    query_states_t, query_states_hw = query_states.chunk(2, dim=-1)
    query_states_t = self.q_norm(query_states_t).transpose(1, 2)
    query_states_hw = self.q_norm_hw(query_states_hw).transpose(1, 2)
    query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1)

    # Keys: same treatment as the queries.
    key_states = self.k_proj(hidden_states).view(hidden_shape)
    key_states_t, key_states_hw = key_states.chunk(2, dim=-1)
    key_states_t = self.k_norm(key_states_t).transpose(1, 2)
    key_states_hw = self.k_norm_hw(key_states_hw).transpose(1, 2)
    key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1)

    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    # Rotate each part with its own position index (t / h / w).
    cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
    query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
    cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
    query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)
    cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
    query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)

    # Reassemble full heads in [t, h, w] order.
    query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
    key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)

    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
        update_cache = kwargs.get("update_cache", True)
        if update_cache:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None)
        else:
            # only use the past key values but do not append the current one
            layer = past_key_values.layers[self.layer_idx]
            past_k, past_v = layer.keys, layer.values
            if past_k is not None:
                key_states = torch.cat([past_k, key_states], dim=2)  # concat on seq_len
                value_states = torch.cat([past_v, value_states], dim=2)

    attention_interface: Callable = eager_attention_forward
    # NOTE(review): cannot trigger while the assert at the top is active; only
    # reachable when assertions are disabled (python -O).
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        sliding_window=self.sliding_window,  # diff with Llama
        **kwargs,
    )

    # Merge the heads back into one feature axis before the output projection.
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
# def forward_gen(
# self,
# hidden_states: torch.Tensor,
# indexes: Optional[torch.LongTensor],
# attention_mask: Optional[torch.Tensor],
# past_key_values: Optional[Cache] = None,
# cache_position: Optional[torch.LongTensor] = None,
# **kwargs: Unpack[FlashAttentionKwargs],
# ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
# assert self.config._attn_implementation == "eager"
# input_shape = hidden_states.shape[:-1]
# hidden_shape = (*input_shape, -1, self.head_dim)
# query_states = self.q_proj_mot_gen(hidden_states).view(hidden_shape)
# query_states_t, query_states_hw = query_states.chunk(2, dim=-1)
# query_states_t = self.q_norm_mot_gen(query_states_t).transpose(1, 2)
# query_states_hw = self.q_norm_hw_mot_gen(query_states_hw).transpose(1, 2)
# query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1)
# key_states = self.k_proj_mot_gen(hidden_states).view(hidden_shape)
# key_states_t, key_states_hw = key_states.chunk(2, dim=-1)
# key_states_t = self.k_norm_mot_gen(key_states_t).transpose(1, 2)
# key_states_hw = self.k_norm_hw_mot_gen(key_states_hw).transpose(1, 2)
# key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1)
# value_states = self.v_proj_mot_gen(hidden_states).view(hidden_shape).transpose(1, 2)
# cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
# query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
# cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
# query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)
# cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
# query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)
# query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
# key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)
# if past_key_values is not None:
# # sin and cos are specific to RoPE models; cache_position needed for the static cache
# # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
# update_cache = kwargs.get("update_cache", True)
# if update_cache:
# key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None)
# else:
# # only use the past key values but do not append the current one
# layer = past_key_values.layers[self.layer_idx]
# past_k, past_v = layer.keys, layer.values
# if past_k is not None:
# key_states = torch.cat([past_k, key_states], dim=2) # concat on seq_len
# value_states = torch.cat([past_v, value_states], dim=2)
# attention_interface: Callable = eager_attention_forward
# if self.config._attn_implementation != "eager":
# attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
# attn_output, attn_weights = attention_interface(
# self,
# query_states,
# key_states,
# value_states,
# attention_mask,
# dropout=0.0 if not self.training else self.attention_dropout,
# scaling=self.scaling,
# sliding_window=self.sliding_window, # diff with Llama
# **kwargs,
# )
# attn_output = attn_output.reshape(*input_shape, -1).contiguous()
# attn_output = self.o_proj_mot_gen(attn_output)
# return attn_output, attn_weights
def forward_gen(
    self,
    hidden_states: torch.Tensor,
    indexes: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Attention pass that uses only the generation-branch (``*_mot_gen``) weights.

    Queries and keys are projected with ``q_proj_mot_gen`` / ``k_proj_mot_gen``,
    split along ``head_dim`` into a "t" half and an "hw" half, normalised
    separately, and the "hw" half is split again into "h" and "w" parts.
    Each part gets its own rotary embedding (``rotary_emb`` for t,
    ``rotary_emb_hw`` for h and w) before the parts are concatenated back.

    Two execution paths follow:

    * ``attention_mask is None``: ``_flash_or_sdpa`` is called with
      ``causal=False`` and no attention weights are returned.
    * otherwise: the attention interface selected by
      ``config._attn_implementation`` is called with the mask.

    Args:
        hidden_states: Input activations; all leading dimensions are kept as
            ``input_shape`` and the inline comments treat them as ``[B, S]``.
        indexes: Indexed as ``indexes[0]``, ``indexes[1]``, ``indexes[2]`` to
            produce the t / h / w rotary embeddings. ``None`` is not handled
            even though the annotation is ``Optional``.
        attention_mask: ``None`` selects the flash/SDPA path; any other value
            is forwarded to the attention interface.
        past_key_values: Optional cache whose per-layer entry is read (and,
            when ``update_cache`` is true, appended to).
        cache_position: Accepted for signature compatibility; not read here.
        **kwargs: ``update_cache`` (default ``True``) is read from here. In the
            masked path all kwargs are also forwarded to the attention
            interface.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has been passed
        through ``o_proj_mot_gen``; ``attn_weights`` is ``None`` on the
        flash/SDPA path.
    """
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)
    # -----------------------------
    # Build q / k / v for current tokens
    # Internal layout before flash:
    # q/k/v: [B, H, S, D]
    # Flash layout:
    # q/k/v: [B, S, H, D]
    # -----------------------------
    query_states = self.q_proj_mot_gen(hidden_states).view(hidden_shape)
    # First half of head_dim is the "t" part, second half the "hw" part.
    query_states_t, query_states_hw = query_states.chunk(2, dim=-1)
    query_states_t = self.q_norm_mot_gen(query_states_t).transpose(1, 2)  # [B,H,S,D/2]
    query_states_hw = self.q_norm_hw_mot_gen(query_states_hw).transpose(1, 2)
    # The "hw" half is normalised as one unit, then split into h and w.
    query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1)
    key_states = self.k_proj_mot_gen(hidden_states).view(hidden_shape)
    key_states_t, key_states_hw = key_states.chunk(2, dim=-1)
    key_states_t = self.k_norm_mot_gen(key_states_t).transpose(1, 2)  # [B,H,S,D/2]
    key_states_hw = self.k_norm_hw_mot_gen(key_states_hw).transpose(1, 2)
    key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1)
    value_states = self.v_proj_mot_gen(hidden_states).view(hidden_shape).transpose(1, 2)  # [B,H,S,D]
    # RoPE: one rotary table per index row (t, h, w); h and w share rotary_emb_hw.
    cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
    query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
    cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
    query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)
    cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
    query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)
    # concat along head_dim
    # query/key current layout: [B, H, S, D]
    query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
    key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)
    # Read without popping: in the masked path below, `update_cache` is still
    # present in kwargs when they are forwarded to the attention interface.
    update_cache = kwargs.get("update_cache", True)
    # ------------------------------------------------------------------
    # Flash path:
    # Only use when there is no explicit dense mask.
    # This is exactly the t2i denoising use case:
    # current image tokens attend to [prefix + current image tokens]
    # fully bidirectional inside current block => causal=False
    # ------------------------------------------------------------------
    if attention_mask is None:
        # Convert current q/k/v to flash layout [B, S, H, D]
        q = query_states.transpose(1, 2).contiguous()
        k_cur = key_states.transpose(1, 2).contiguous()
        v_cur = value_states.transpose(1, 2).contiguous()
        if past_key_values is not None:
            if update_cache:
                # Rare path, keep compatibility.
                # past_key_values.update expects [B,H,S,D]
                key_states, value_states = past_key_values.update(
                    key_states, value_states, self.layer_idx, cache_kwargs=None
                )
                k = key_states.transpose(1, 2).contiguous()
                v = value_states.transpose(1, 2).contiguous()
            else:
                # Optimized path:
                # use preallocated flash_k_cache / flash_v_cache
                layer = past_key_values.layers[self.layer_idx]
                if (
                    hasattr(layer, "flash_k_cache")
                    and layer.flash_k_cache is not None
                    and hasattr(layer, "flash_v_cache")
                    and layer.flash_v_cache is not None
                ):
                    # NOTE(review): flash_prefix_len is read without a hasattr
                    # check; presumably it is set together with the flash
                    # caches by code outside this method - confirm.
                    prefix_len = layer.flash_prefix_len
                    cur_len = k_cur.shape[1]
                    # overwrite current segment in-place
                    layer.flash_k_cache[:, prefix_len:prefix_len + cur_len].copy_(k_cur)
                    layer.flash_v_cache[:, prefix_len:prefix_len + cur_len].copy_(v_cur)
                    # Attend over prefix + current segment as views of the cache.
                    k = layer.flash_k_cache[:, :prefix_len + cur_len]
                    v = layer.flash_v_cache[:, :prefix_len + cur_len]
                else:
                    # fallback if user forgot to prepare flash cache
                    layer = past_key_values.layers[self.layer_idx]
                    past_k, past_v = layer.keys, layer.values
                    if past_k is not None:
                        # Cached tensors are [B,H,S,D]; move to flash layout first.
                        past_k = past_k.transpose(1, 2).contiguous()
                        past_v = past_v.transpose(1, 2).contiguous()
                        k = torch.cat([past_k, k_cur], dim=1)
                        v = torch.cat([past_v, v_cur], dim=1)
                    else:
                        k = k_cur
                        v = v_cur
        else:
            k = k_cur
            v = v_cur
        # sanity checks
        assert q.ndim == 4 and k.ndim == 4 and v.ndim == 4
        assert q.shape[0] == k.shape[0] == v.shape[0], (q.shape, k.shape, v.shape)
        assert k.shape[1] == v.shape[1], (k.shape, v.shape)
        assert k.shape[2] == v.shape[2], (k.shape, v.shape)
        assert q.shape[3] == k.shape[3] == v.shape[3], (q.shape, k.shape, v.shape)
        attn_output = _flash_or_sdpa(
            q,
            k,
            v,
            dropout_p=0.0 if not self.training else self.attention_dropout,
            softmax_scale=self.scaling,
            causal=False,
        )  # [B, S_q, H_q, D]
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj_mot_gen(attn_output)
        # This path does not produce attention weights.
        return attn_output, None
    # ------------------------------------------------------------------
    # Original eager fallback path
    # ------------------------------------------------------------------
    if past_key_values is not None:
        if update_cache:
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs=None
            )
        else:
            # Use the cached K/V for this call without appending the current ones.
            layer = past_key_values.layers[self.layer_idx]
            past_k, past_v = layer.keys, layer.values
            if past_k is not None:
                key_states = torch.cat([past_k, key_states], dim=2)
                value_states = torch.cat([past_v, value_states], dim=2)
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        sliding_window=self.sliding_window,
        **kwargs,
    )
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj_mot_gen(attn_output)
    return attn_output, attn_weights
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
    self,
    hidden_states: torch.Tensor,
    image_gen_indicators: torch.Tensor,
    exist_non_image_gen_tokens: bool,
    exist_image_gen_tokens: bool,
    indexes: Optional[torch.LongTensor],
    attention_mask: Optional[torch.Tensor],
    past_key_values: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Dispatch attention to the understanding branch, the generation branch, or both.

    * Only non-generation tokens present -> ``forward_und``.
    * Only generation tokens present -> ``forward_gen``.
    * Otherwise (mixed sequence) -> every projection and norm is applied
      per token: positions where ``image_gen_indicators`` is true use the
      ``*_mot_gen`` weights, all other positions use the plain weights.
      Attention itself then runs once over the whole sequence.

    Args:
        hidden_states: Input activations; leading dimensions are kept as
            ``input_shape``.
        image_gen_indicators: Boolean mask over the leading dimensions of
            ``hidden_states`` selecting generation tokens.
        exist_non_image_gen_tokens: Whether any position is a non-generation token.
        exist_image_gen_tokens: Whether any position is a generation token.
        indexes: Rows 0 / 1 / 2 feed the t / h / w rotary embeddings.
        attention_mask: Mask forwarded to the attention interface.
        past_key_values: Optional cache, read and (unless
            ``update_cache=False`` is given in kwargs) appended to.
        cache_position: Forwarded to ``forward_und`` / ``forward_gen``; not
            read in the mixed path.
        **kwargs: ``update_cache`` (default ``True``) is read from here; all
            kwargs are forwarded to the attention interface.

    Returns:
        ``(attn_output, attn_weights)``.

    Raises:
        AssertionError: In the mixed path when
            ``config._attn_implementation`` is not ``"eager"``.
    """
    if exist_non_image_gen_tokens and not exist_image_gen_tokens:
        return self.forward_und(hidden_states, indexes, attention_mask, past_key_values, cache_position, **kwargs)
    if not exist_non_image_gen_tokens and exist_image_gen_tokens:
        return self.forward_gen(hidden_states, indexes, attention_mask, past_key_values, cache_position, **kwargs)

    # The mixed path is only supported with the eager attention implementation.
    assert self.config._attn_implementation == "eager"
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    # Masks are only materialised for branches that actually have tokens.
    und_mask = ~image_gen_indicators if exist_non_image_gen_tokens else None
    gen_mask = image_gen_indicators if exist_image_gen_tokens else None

    def _route_linear(und_proj, gen_proj, num_heads):
        # Apply the branch-specific projection per token, then expose the
        # per-head layout so head_dim can be split below (same layout that
        # forward_gen builds with `.view(hidden_shape)`).
        out = hidden_states.new_zeros((*input_shape, num_heads * self.head_dim))
        if und_mask is not None:
            out[und_mask] = und_proj(hidden_states[und_mask])
        if gen_mask is not None:
            out[gen_mask] = gen_proj(hidden_states[gen_mask])
        return out.view(hidden_shape)

    def _route_norm(states, und_norm, gen_norm):
        # Apply the branch-specific norm per token and move heads in front
        # of the sequence axis, as forward_gen does with `.transpose(1, 2)`.
        out = states.new_zeros(states.shape)
        if und_mask is not None:
            out[und_mask] = und_norm(states[und_mask])
        if gen_mask is not None:
            out[gen_mask] = gen_norm(states[gen_mask])
        return out.transpose(1, 2)

    # NOTE(review): `self.q_norm` / `self.k_norm` are assumed to be the
    # understanding-branch counterparts of `q_norm_mot_gen` / `k_norm_mot_gen`
    # (following the `<name>` / `<name>_mot_gen` pairing used by every other
    # sub-module here) - confirm against `forward_und`.
    query_t, query_hw = _route_linear(
        self.q_proj, self.q_proj_mot_gen, self.config.num_attention_heads
    ).chunk(2, dim=-1)
    query_states_t = _route_norm(query_t, self.q_norm, self.q_norm_mot_gen)
    query_states_h, query_states_w = _route_norm(
        query_hw, self.q_norm_hw, self.q_norm_hw_mot_gen
    ).chunk(2, dim=-1)

    key_t, key_hw = _route_linear(
        self.k_proj, self.k_proj_mot_gen, self.config.num_key_value_heads
    ).chunk(2, dim=-1)
    key_states_t = _route_norm(key_t, self.k_norm, self.k_norm_mot_gen)
    key_states_h, key_states_w = _route_norm(
        key_hw, self.k_norm_hw, self.k_norm_hw_mot_gen
    ).chunk(2, dim=-1)

    value_states = _route_linear(
        self.v_proj, self.v_proj_mot_gen, self.config.num_key_value_heads
    ).transpose(1, 2)

    # RoPE: one rotary table per index row (t, h, w); h and w share rotary_emb_hw.
    cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
    query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
    cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
    query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)
    cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
    query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)
    query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
    key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)

    if past_key_values is not None:
        update_cache = kwargs.get("update_cache", True)
        if update_cache:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None)
        else:
            # only use the past key values but do not append the current one
            layer = past_key_values.layers[self.layer_idx]
            past_k, past_v = layer.keys, layer.values
            if past_k is not None:
                key_states = torch.cat([past_k, key_states], dim=2)  # concat on seq_len
                value_states = torch.cat([past_v, value_states], dim=2)

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        sliding_window=self.sliding_window,  # diff with Llama
        **kwargs,
    )
    attn_output = attn_output.reshape(*input_shape, -1).contiguous()

    # Output projection is also routed per token.
    _attn_output = attn_output.new_zeros((*input_shape, self.config.hidden_size))
    if und_mask is not None:
        _attn_output[und_mask] = self.o_proj(attn_output[und_mask])
    if gen_mask is not None:
        _attn_output[gen_mask] = self.o_proj_mot_gen(attn_output[gen_mask])
    return _attn_output, attn_weights
class Qwen3DecoderLayer(GradientCheckpointingLayer):
    """Decoder block with parallel weights for ordinary and image-generation tokens.

    The self-attention module is shared. The two layer norms and the MLP each
    exist twice: the plain version and a ``*_mot_gen`` version used for
    positions flagged by ``image_gen_indicators``.
    """

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width
        self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.mlp_mot_gen = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(width, eps=eps)
        self.input_layernorm_mot_gen = Qwen3RMSNorm(width, eps=eps)
        self.post_attention_layernorm = Qwen3RMSNorm(width, eps=eps)
        self.post_attention_layernorm_mot_gen = Qwen3RMSNorm(width, eps=eps)
        self.attention_type = config.layer_types[layer_idx]

    def _attend(
        self,
        normed_states,
        image_gen_indicators,
        exist_non_image_gen_tokens,
        exist_image_gen_tokens,
        indexes,
        attention_mask,
        position_ids,
        past_key_values,
        use_cache,
        cache_position,
        **kwargs,
    ):
        """Run the shared self-attention module and return only its output tensor."""
        attn_out, _ = self.self_attn(
            hidden_states=normed_states,
            image_gen_indicators=image_gen_indicators,
            exist_non_image_gen_tokens=exist_non_image_gen_tokens,
            exist_image_gen_tokens=exist_image_gen_tokens,
            indexes=indexes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        return attn_out

    def forward_und(
        self,
        hidden_states: torch.Tensor,
        image_gen_indicators: torch.Tensor,
        exist_non_image_gen_tokens: bool,
        exist_image_gen_tokens: bool,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Block forward that applies the plain (non ``_mot_gen``) weights to every position."""
        attn_out = self._attend(
            self.input_layernorm(hidden_states),
            image_gen_indicators,
            exist_non_image_gen_tokens,
            exist_image_gen_tokens,
            indexes,
            attention_mask,
            position_ids,
            past_key_values,
            use_cache,
            cache_position,
            **kwargs,
        )
        after_attn = hidden_states + attn_out
        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out

    def forward_gen(
        self,
        hidden_states: torch.Tensor,
        image_gen_indicators: torch.Tensor,
        exist_non_image_gen_tokens: bool,
        exist_image_gen_tokens: bool,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Block forward that applies the ``*_mot_gen`` weights to every position."""
        attn_out = self._attend(
            self.input_layernorm_mot_gen(hidden_states),
            image_gen_indicators,
            exist_non_image_gen_tokens,
            exist_image_gen_tokens,
            indexes,
            attention_mask,
            position_ids,
            past_key_values,
            use_cache,
            cache_position,
            **kwargs,
        )
        after_attn = hidden_states + attn_out
        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp_mot_gen(self.post_attention_layernorm_mot_gen(after_attn))
        return after_attn + mlp_out

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        image_gen_indicators: torch.Tensor,
        exist_non_image_gen_tokens: bool,
        exist_image_gen_tokens: bool,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Route to ``forward_und`` / ``forward_gen``, or mix both per token.

        When only one token kind is present the matching single-branch
        method handles the whole sequence. Otherwise norms and MLPs are
        applied per position according to ``image_gen_indicators`` while
        self-attention runs once over the full sequence.
        """
        only_und = exist_non_image_gen_tokens and not exist_image_gen_tokens
        if only_und:
            return self.forward_und(hidden_states, image_gen_indicators, exist_non_image_gen_tokens, exist_image_gen_tokens, indexes, attention_mask, position_ids, past_key_values, use_cache, cache_position, **kwargs)
        only_gen = not exist_non_image_gen_tokens and exist_image_gen_tokens
        if only_gen:
            return self.forward_gen(hidden_states, image_gen_indicators, exist_non_image_gen_tokens, exist_image_gen_tokens, indexes, attention_mask, position_ids, past_key_values, use_cache, cache_position, **kwargs)

        # Pre-attention norm, selected per token.
        normed = hidden_states.new_zeros(hidden_states.shape)
        if exist_non_image_gen_tokens:
            und_mask = ~image_gen_indicators
            normed[und_mask] = self.input_layernorm(hidden_states[und_mask])
        if exist_image_gen_tokens:
            normed[image_gen_indicators] = self.input_layernorm_mot_gen(hidden_states[image_gen_indicators])

        attn_out = self._attend(
            normed,
            image_gen_indicators,
            exist_non_image_gen_tokens,
            exist_image_gen_tokens,
            indexes,
            attention_mask,
            position_ids,
            past_key_values,
            use_cache,
            cache_position,
            **kwargs,
        )
        after_attn = hidden_states + attn_out

        # Post-attention norm + MLP, selected per token.
        mlp_out = after_attn.new_zeros(after_attn.shape)
        if exist_non_image_gen_tokens:
            und_mask = ~image_gen_indicators
            mlp_out[und_mask] = self.mlp(self.post_attention_layernorm(after_attn[und_mask]))
        if exist_image_gen_tokens:
            mlp_out[image_gen_indicators] = self.mlp_mot_gen(
                self.post_attention_layernorm_mot_gen(after_attn[image_gen_indicators])
            )
        return after_attn + mlp_out
@auto_docstring
class Qwen3PreTrainedModel(PreTrainedModel):
    # Common base for the Qwen3 model classes in this file. It defines no
    # methods; it only declares the config type and class-level attributes
    # that are read by the `PreTrainedModel` base class (their exact effect
    # is defined by transformers, not in this file).
    config: Qwen3Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output name to the module class associated with that output.
    _can_record_outputs = {
        "hidden_states": Qwen3DecoderLayer,
        "attentions": Qwen3Attention,
    }
@auto_docstring
class Qwen3Model(Qwen3PreTrainedModel):
    # Decoder stack (embedding -> N decoder layers -> final norm). The final
    # norm exists twice: `norm` for ordinary tokens and `norm_mot_gen` for
    # positions flagged by `image_gen_indicators`.
    def __init__(self, config: Qwen3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        # Position counter kept on the module across forward() calls: it is
        # incremented by one on the `input_ids` path and overwritten with
        # `indexes[0].max()` on every other path.
        self.current_index = -1
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_gen_indicators: Optional[torch.Tensor] = None,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # assert position_ids is not None
        # assert cache_position is not None
        # assert past_key_values is not None
        # Without indicators every position is treated as a non-generation token.
        if image_gen_indicators is None:
            exist_non_image_gen_tokens = True
            exist_image_gen_tokens = False
        else:
            # NOTE(review): `.any()` yields a 0-d tensor here, not a Python
            # bool, although the layer signatures annotate these as `bool`.
            exist_non_image_gen_tokens = (~image_gen_indicators).any()
            exist_image_gen_tokens = image_gen_indicators.any()
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            if input_ids is not None:
                mask_kwargs = {
                    "config": self.config,
                    "input_embeds": inputs_embeds,
                    "attention_mask": attention_mask,
                    "cache_position": cache_position,
                    "past_key_values": past_key_values,
                    "position_ids": position_ids,
                }
                # Create the masks
                causal_mask_mapping = {
                    "full_attention": create_causal_mask(**mask_kwargs),
                }
                # Token-id path: advance the stored counter by one and build
                # `indexes` from it (t = counter, h = w = 0). Any `indexes`
                # argument supplied by the caller is replaced here.
                self.current_index += 1
                indexes = torch.LongTensor([[self.current_index], [0], [0]]).to(input_ids.device)
            else:
                # Embedding path: `indexes` must be provided; its first row
                # drives both the block-causal mask and the stored counter.
                causal_mask_mapping = {
                    "full_attention": create_block_causal_mask(indexes[0]),
                }
                self.current_index = indexes[0].max()
        else:
            # Caller supplied a ready-made mask mapping; only the counter is synced.
            self.current_index = indexes[0].max()
            # raise NotImplementedError('not isinstance(causal_mask_mapping := attention_mask, dict)')
        # The sliding window alternating layers are not always activated depending on the config
        # if self.has_sliding_layers:
        # causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
        # NOTE(review): the mappings built above only contain "full_attention",
        # so a layer with any other `attention_type` would raise KeyError below.
        hidden_states = inputs_embeds
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                image_gen_indicators=image_gen_indicators,
                exist_non_image_gen_tokens=exist_non_image_gen_tokens,
                exist_image_gen_tokens=exist_image_gen_tokens,
                indexes=indexes,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
        # Final norm: single-branch when only one token kind is present,
        # otherwise selected per position.
        if not exist_image_gen_tokens:
            hidden_states = self.norm(hidden_states)
        elif not exist_non_image_gen_tokens:
            hidden_states = self.norm_mot_gen(hidden_states)
        else:
            _hidden_states = hidden_states.new_zeros(hidden_states.shape)
            _hidden_states[~image_gen_indicators] = self.norm(hidden_states[~image_gen_indicators])
            _hidden_states[image_gen_indicators] = self.norm_mot_gen(hidden_states[image_gen_indicators])
            hidden_states = _hidden_states
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
    # Qwen3Model backbone followed by a bias-free linear head over the vocabulary.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM

        >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            indexes=indexes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        features = backbone_out.last_hidden_state

        # An int keeps the trailing `logits_to_keep` positions (0 keeps all,
        # because slice(-0, None) is the full range); a tensor is used directly
        # as the index along the sequence axis.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(features[:, keep, :])

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        # `hidden_states` carries the backbone's final hidden state
        # (all positions, not only the kept ones).
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=features,
            attentions=backbone_out.attentions,
        )
class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
    # No overrides: everything is inherited from the two base classes.
    pass
class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
    # No overrides: everything is inherited from the two base classes.
    pass
class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
    # Only the backbone attribute prefix differs from the shared base class
    # (which sets it to "model").
    base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
# Names exported by `from <this module> import *`. Other classes defined in
# this module are not listed and must be imported by name.
__all__ = [
    "Qwen3ForCausalLM",
    "Qwen3ForQuestionAnswering",
    "Qwen3PreTrainedModel",
    "Qwen3Model",
    "Qwen3ForSequenceClassification",
    "Qwen3ForTokenClassification",
]
\ No newline at end of file
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from .configuration_neo_chat import NEOMoELLMConfig
from .modeling_qwen3 import (
Qwen3Attention,
Qwen3RMSNorm,
create_block_causal_mask,
)
class Qwen3MoeMLP(nn.Module):
    """Gated feed-forward network used as a single MoE expert or as a dense MLP.

    The inner width is configurable: pass ``intermediate_size`` explicitly
    (e.g. the per-expert ``moe_intermediate_size``) or leave it as ``None``
    to fall back to ``config.intermediate_size``.
    """

    def __init__(self, config, intermediate_size: Optional[int] = None):
        super().__init__()
        from transformers.activations import ACT2FN

        self.config = config
        self.hidden_size = config.hidden_size
        if intermediate_size is None:
            intermediate_size = config.intermediate_size
        self.intermediate_size = intermediate_size

        model_dim, inner_dim = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
        gate = self.act_fn(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate * lifted)
class Qwen3MoeSparseMoeBlock(nn.Module):
    """Top-k softmax-routed MoE block matching HuggingFace's Qwen3-MoE layout.

    Parameter names (``gate.weight``, ``experts.{i}.gate_proj/up_proj/down_proj``)
    are kept identical so converted A3B checkpoints load directly via the
    ``mlp.*`` / ``mlp_mot_gen.*`` keys. The block is parameterised explicitly
    so the same class can serve both the understanding branch (``num_experts``
    experts, top-k = ``num_experts_per_tok``, width ``moe_intermediate_size``)
    and the image-generation branch (``gen_num_experts`` etc.).
    """

    def __init__(
        self,
        config: NEOMoELLMConfig,
        num_experts: Optional[int] = None,
        num_experts_per_tok: Optional[int] = None,
        moe_intermediate_size: Optional[int] = None,
    ):
        """Build the router and the expert list.

        Args:
            config: Supplies ``hidden_size``, the optional ``norm_topk_prob``
                flag (defaults to ``True`` when absent) and the fallbacks for
                the three arguments below.
            num_experts: Number of experts; ``config.num_experts`` if ``None``.
            num_experts_per_tok: Experts selected per token;
                ``config.num_experts_per_tok`` if ``None``.
            moe_intermediate_size: Inner width of each expert;
                ``config.moe_intermediate_size`` if ``None``.
        """
        super().__init__()
        self.num_experts = int(num_experts) if num_experts is not None else int(config.num_experts)
        self.top_k = int(
            num_experts_per_tok if num_experts_per_tok is not None else config.num_experts_per_tok
        )
        self.norm_topk_prob = bool(getattr(config, "norm_topk_prob", True))
        self.hidden_size = config.hidden_size
        expert_intermediate_size = int(
            moe_intermediate_size
            if moe_intermediate_size is not None
            else config.moe_intermediate_size
        )
        self.gate = nn.Linear(config.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                Qwen3MoeMLP(config, intermediate_size=expert_intermediate_size)
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token to its top-k experts and sum the weighted expert outputs.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size; all
                leading dimensions are flattened into a token axis. The input
                does not need to be contiguous.

        Returns:
            Tensor with the same shape and dtype as ``hidden_states``.
        """
        orig_shape = hidden_states.shape
        hidden_dim = orig_shape[-1]
        # `reshape` instead of `view`: identical for contiguous inputs, but
        # also accepts non-contiguous ones (e.g. transposed tensors), for
        # which `view` raises.
        flat = hidden_states.reshape(-1, hidden_dim)
        n_tokens = flat.shape[0]

        router_logits = self.gate(flat)
        # Softmax in float32, then keep the top-k probabilities per token.
        routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            # Renormalise so the kept weights of each token sum to one.
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(flat.dtype)

        output = torch.zeros(
            (n_tokens, hidden_dim), dtype=flat.dtype, device=flat.device
        )
        # (num_experts, top_k, num_tokens)
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        for expert_idx in range(self.num_experts):
            # slot_idx: which of the top-k slots picked this expert;
            # token_idx: the token that picked it.
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.numel() == 0:
                continue
            expert_layer = self.experts[expert_idx]
            current_state = flat.index_select(0, token_idx)
            current_out = expert_layer(current_state) * routing_weights[token_idx, slot_idx, None]
            # Accumulate: a token receives one contribution per selected expert.
            output.index_add_(0, token_idx, current_out.to(flat.dtype))
        return output.view(*orig_shape)
class Qwen3MoeDecoderLayer(GradientCheckpointingLayer):
"""A Qwen3-MoE decoder block with the NEO-Unify two-branch structure.
Mirrors ``Qwen3DecoderLayer`` from :mod:`modeling_qwen3` but uses sparse
MoE blocks on *both* branches:
* ``self.mlp`` - understanding-path MoE
(``num_experts`` / ``num_experts_per_tok`` /
``moe_intermediate_size``)
* ``self.mlp_mot_gen`` - image-generation-path MoE
(``gen_num_experts`` / ``gen_num_experts_per_tok`` /
``gen_moe_intermediate_size``)
Layers listed in ``mlp_only_layers`` or those not aligned with
``decoder_sparse_step`` fall back to a dense :class:`Qwen3MoeMLP` on the
understanding branch (matching upstream Qwen3-MoE), while the
generation branch still uses a sparse MoE.
"""
def __init__(self, config: NEOMoELLMConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
mlp_only_layers = list(getattr(config, "mlp_only_layers", []) or [])
decoder_sparse_step = int(getattr(config, "decoder_sparse_step", 1) or 1)
is_sparse = (
int(config.num_experts) > 0
and layer_idx not in mlp_only_layers
and (layer_idx + 1) % decoder_sparse_step == 0
)
if is_sparse:
self.mlp = Qwen3MoeSparseMoeBlock(
config,
num_experts=config.num_experts,
num_experts_per_tok=config.num_experts_per_tok,
moe_intermediate_size=config.moe_intermediate_size,
)
else:
self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
# Image-generation branch: in the A3B checkpoint this is *also* a sparse
# MoE block (``gen_num_experts`` experts, typically smaller than the und
# branch's ``num_experts``). ``NEOMoELLMConfig`` defaults the gen-path
# knobs to their und-path counterparts so legacy single-pool configs
# keep working.
self.mlp_mot_gen = Qwen3MoeSparseMoeBlock(
config,
num_experts=getattr(config, "gen_num_experts", config.num_experts),
num_experts_per_tok=getattr(
config, "gen_num_experts_per_tok", config.num_experts_per_tok
),
moe_intermediate_size=getattr(
config, "gen_moe_intermediate_size", config.moe_intermediate_size
),
)
self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention_type = config.layer_types[layer_idx]
def forward_und(
    self,
    hidden_states: torch.Tensor,
    image_gen_indicators: torch.Tensor,
    exist_non_image_gen_tokens: bool,
    exist_image_gen_tokens: bool,
    indexes: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
    """Run the layer with the understanding-branch norms and MLP only.

    Applies ``input_layernorm`` -> shared self-attention -> residual add,
    then ``post_attention_layernorm`` -> ``mlp`` -> residual add. All
    arguments other than ``hidden_states`` are forwarded unchanged to
    ``self_attn``.

    Returns:
        The updated hidden states.
    """
    # Pre-norm attention sub-block; the second element returned by
    # ``self_attn`` is discarded.
    attn_output, _ = self.self_attn(
        hidden_states=self.input_layernorm(hidden_states),
        image_gen_indicators=image_gen_indicators,
        exist_non_image_gen_tokens=exist_non_image_gen_tokens,
        exist_image_gen_tokens=exist_image_gen_tokens,
        indexes=indexes,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    post_attn = hidden_states + attn_output

    # Pre-norm feed-forward sub-block.
    mlp_output = self.mlp(self.post_attention_layernorm(post_attn))
    return post_attn + mlp_output
def forward_gen(
    self,
    hidden_states: torch.Tensor,
    image_gen_indicators: torch.Tensor,
    exist_non_image_gen_tokens: bool,
    exist_image_gen_tokens: bool,
    indexes: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
    """Run the layer with the image-generation norms and MoE only.

    Applies ``input_layernorm_mot_gen`` -> shared self-attention -> residual
    add, then ``post_attention_layernorm_mot_gen`` -> ``mlp_mot_gen`` ->
    residual add. All arguments other than ``hidden_states`` are forwarded
    unchanged to ``self_attn``.

    Returns:
        The updated hidden states.
    """
    # Pre-norm attention sub-block; the second element returned by
    # ``self_attn`` is discarded.
    attn_output, _ = self.self_attn(
        hidden_states=self.input_layernorm_mot_gen(hidden_states),
        image_gen_indicators=image_gen_indicators,
        exist_non_image_gen_tokens=exist_non_image_gen_tokens,
        exist_image_gen_tokens=exist_image_gen_tokens,
        indexes=indexes,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    post_attn = hidden_states + attn_output

    # Pre-norm feed-forward sub-block on the generation experts.
    moe_output = self.mlp_mot_gen(self.post_attention_layernorm_mot_gen(post_attn))
    return post_attn + moe_output
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
    self,
    hidden_states: torch.Tensor,
    image_gen_indicators: torch.Tensor,
    exist_non_image_gen_tokens: bool,
    exist_image_gen_tokens: bool,
    indexes: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
    """Run one decoder layer, routing tokens to the und and/or gen branch.

    When only one kind of token is present the call is delegated to
    :meth:`forward_und` or :meth:`forward_gen`. Otherwise (mixed input) the
    norms and feed-forward blocks are applied per token group, selected by
    ``image_gen_indicators``, and the results are scattered back into a
    tensor of the original shape; self-attention is shared by both groups.

    Args:
        hidden_states: Input activations of the layer.
        image_gen_indicators: Boolean mask; ``True`` marks image-generation
            tokens. Used to index ``hidden_states`` in the mixed path.
        exist_non_image_gen_tokens: Whether any understanding token is present.
        exist_image_gen_tokens: Whether any image-generation token is present.
        indexes, attention_mask, position_ids, past_key_values, use_cache,
        cache_position, **kwargs: Forwarded unchanged to ``self_attn``.

    Returns:
        The updated hidden states.
    """
    if exist_non_image_gen_tokens and not exist_image_gen_tokens:
        return self.forward_und(
            hidden_states, image_gen_indicators, exist_non_image_gen_tokens,
            exist_image_gen_tokens, indexes, attention_mask, position_ids,
            past_key_values, use_cache, cache_position, **kwargs,
        )
    if not exist_non_image_gen_tokens and exist_image_gen_tokens:
        return self.forward_gen(
            hidden_states, image_gen_indicators, exist_non_image_gen_tokens,
            exist_image_gen_tokens, indexes, attention_mask, position_ids,
            past_key_values, use_cache, cache_position, **kwargs,
        )

    def _apply_mlp(mlp, tokens):
        # Mask-selected tokens come out 2D (num_tokens, hidden), while the MoE
        # blocks expect a 3D (batch, seq, hidden) input: add a batch axis and
        # drop it again. Used for BOTH branches, since ``mlp_mot_gen`` is
        # always a sparse MoE block.
        if tokens.dim() == 2:
            return mlp(tokens.unsqueeze(0)).squeeze(0)
        return mlp(tokens)

    # Mixed batch: dispatch tokens per branch then merge back. Matches the
    # dense ``Qwen3DecoderLayer.forward`` mixed-path implementation.
    gen_mask = image_gen_indicators
    und_mask = ~image_gen_indicators

    residual = hidden_states
    normed = hidden_states.new_zeros(hidden_states.shape)
    if exist_non_image_gen_tokens:
        normed[und_mask] = self.input_layernorm(hidden_states[und_mask])
    if exist_image_gen_tokens:
        normed[gen_mask] = self.input_layernorm_mot_gen(hidden_states[gen_mask])

    # Attention is shared; the second element it returns is discarded.
    attn_output, _ = self.self_attn(
        hidden_states=normed,
        image_gen_indicators=image_gen_indicators,
        exist_non_image_gen_tokens=exist_non_image_gen_tokens,
        exist_image_gen_tokens=exist_image_gen_tokens,
        indexes=indexes,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = residual + attn_output

    residual = hidden_states
    mlp_output = hidden_states.new_zeros(hidden_states.shape)
    if exist_non_image_gen_tokens:
        und_tokens = self.post_attention_layernorm(hidden_states[und_mask])
        mlp_output[und_mask] = _apply_mlp(self.mlp, und_tokens)
    if exist_image_gen_tokens:
        gen_tokens = self.post_attention_layernorm_mot_gen(hidden_states[gen_mask])
        mlp_output[gen_mask] = _apply_mlp(self.mlp_mot_gen, gen_tokens)

    return residual + mlp_output
@auto_docstring
class Qwen3MoePreTrainedModel(PreTrainedModel):
    # Common base for the NEO Qwen3-MoE models. It only declares class-level
    # settings; they are presumably consumed by ``PreTrainedModel`` and the
    # surrounding framework (not visible in this file section).

    config: NEOMoELLMConfig
    base_model_prefix = "model"

    # Module/key placement hints.
    _no_split_modules = ["Qwen3MoeDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Capability flags.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False  # MoE routing has data-dependent control flow.

    # Output name -> module class whose outputs are recorded under that name.
    _can_record_outputs = dict(
        hidden_states=Qwen3MoeDecoderLayer,
        attentions=Qwen3Attention,
    )
@auto_docstring
class Qwen3MoeModel(Qwen3MoePreTrainedModel):
    # Decoder stack: token embedding, ``num_hidden_layers`` decoder layers and
    # two final RMSNorms (``norm`` for understanding tokens, ``norm_mot_gen``
    # for image-generation tokens).

    def __init__(self, config: NEOMoELLMConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = [
            Qwen3MoeDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        # Running index maintained by ``forward``: incremented on each
        # ``input_ids`` call, otherwise reset to ``indexes[0].max()``.
        self.current_index = -1

        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        image_gen_indicators: Optional[torch.Tensor] = None,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Without indicators every token is treated as an understanding token.
        if image_gen_indicators is not None:
            exist_non_image_gen_tokens = (~image_gen_indicators).any()
            exist_image_gen_tokens = image_gen_indicators.any()
        else:
            exist_non_image_gen_tokens = True
            exist_image_gen_tokens = False

        # Exactly one of the two input forms must be supplied.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            cache_position = torch.arange(
                seen_tokens, seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # A dict ``attention_mask`` is used as the ready-made per-attention-type
        # mapping; otherwise the "full_attention" mask is built here.
        causal_mask_mapping = attention_mask
        if isinstance(causal_mask_mapping, dict):
            self.current_index = indexes[0].max()
        elif input_ids is not None:
            # Token-id input: standard causal mask, advance the running index
            # and synthesize ``indexes`` from it.
            mask_inputs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            causal_mask_mapping = {"full_attention": create_causal_mask(**mask_inputs)}
            self.current_index += 1
            indexes = torch.LongTensor([[self.current_index], [0], [0]]).to(input_ids.device)
        else:
            # Embedding input: block-causal mask derived from the caller's
            # ``indexes``; the running index is reset to their maximum.
            causal_mask_mapping = {"full_attention": create_block_causal_mask(indexes[0])}
            self.current_index = indexes[0].max()

        hidden_states = inputs_embeds
        for layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                image_gen_indicators=image_gen_indicators,
                exist_non_image_gen_tokens=exist_non_image_gen_tokens,
                exist_image_gen_tokens=exist_image_gen_tokens,
                indexes=indexes,
                attention_mask=causal_mask_mapping[layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        # Final norm: single-branch fast paths, else per-token-group norms
        # scattered back into a tensor of the same shape.
        if not exist_image_gen_tokens:
            hidden_states = self.norm(hidden_states)
        elif not exist_non_image_gen_tokens:
            hidden_states = self.norm_mot_gen(hidden_states)
        else:
            normed = hidden_states.new_zeros(hidden_states.shape)
            und_mask = ~image_gen_indicators
            normed[und_mask] = self.norm(hidden_states[und_mask])
            normed[image_gen_indicators] = self.norm_mot_gen(hidden_states[image_gen_indicators])
            hidden_states = normed

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
@auto_docstring
class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):
    # Causal-LM wrapper: ``Qwen3MoeModel`` backbone plus a bias-free linear
    # head projecting hidden states to vocabulary logits.

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config: NEOMoELLMConfig):
        super().__init__(config)
        self.model = Qwen3MoeModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        indexes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            indexes=indexes,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        final_hidden = backbone_out.last_hidden_state

        # An int keeps the last ``logits_to_keep`` positions (0 -> slice(0, None),
        # i.e. all of them); any other value is used directly as the index.
        if isinstance(logits_to_keep, int):
            keep = slice(-logits_to_keep, None)
        else:
            keep = logits_to_keep
        logits = self.lm_head(final_hidden[:, keep, :])

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
            )

        # ``hidden_states`` carries the backbone's final hidden state.
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=final_hidden,
            attentions=backbone_out.attentions,
        )
# Names exported by ``from <this module> import *``.
__all__ = [
"Qwen3MoeForCausalLM",
"Qwen3MoeModel",
"Qwen3MoePreTrainedModel",
"Qwen3MoeDecoderLayer",
"Qwen3MoeSparseMoeBlock",
"Qwen3MoeMLP",
]
from __future__ import annotations
import math
import torch
import torchvision.transforms as T
from PIL import Image
# System message for image generation / editing. The text describes the
# think / non-think modes, the two task types (text-to-image, image editing)
# and the rule for the language of rendered text. The string is consumed at
# runtime, so its wording must not be changed casually.
# NOTE(review): where it is injected is not visible in this section.
SYSTEM_MESSAGE_FOR_GEN = (
"You are an image generation and editing assistant that accurately understands and executes "
"user intent.\n\nYou support two modes:\n\n1. Think Mode:\nIf the task requires reasoning, you "
"MUST start with a <think></think> block. Put all reasoning inside the block using plain text. "
"DO NOT include any image tags. Keep it reasonable and directly useful for producing the final "
"image.\n\n2. Non-Think Mode:\nIf no reasoning is needed, directly produce the final image.\n\n"
"Task Types:\n\nA. Text-to-Image Generation:\n"
"- Generate a high-quality image based on the user's description.\n"
"- Ensure visual clarity, semantic consistency, and completeness.\n"
"- DO NOT introduce elements that contradict or override the user's intent.\n\n"
"B. Image Editing:\n"
"- Use the provided image(s) as input or reference for modification or transformation.\n"
"- The result can be an edited image or a new image based on the reference(s).\n"
"- Preserve all unspecified attributes unless explicitly changed.\n\n"
"General Rules:\n"
"- For any visible text in the image, follow the language specified for the rendered text in "
"the user's description, not the language of the prompt. If no language is specified, use the "
"user's input language."
)
# Per-channel mean and standard deviation passed to ``T.Normalize`` in
# ``load_image_native``; one value per RGB channel.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def round_by_factor(number: float, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number``.

    Uses Python's built-in ``round``, so exact halves go to the even multiple.
    """
    nearest_multiple_count = round(number / factor)
    return nearest_multiple_count * factor
def ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    multiples_needed = math.ceil(number / factor)
    return multiples_needed * factor
def floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    whole_multiples = math.floor(number / factor)
    return whole_multiples * factor
def smart_resize(
    height: int,
    width: int,
    factor: int = 32,
    min_pixels: int = 65536,
    max_pixels: int = 4194304,
) -> tuple[int, int]:
    """Rescale so that H/W are divisible by `factor` and total pixels ∈ [min, max].

    Copied from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60

    Args:
        height: Source height in pixels; must be positive.
        width: Source width in pixels; must be positive.
        factor: Both returned sides are multiples of this value (and at
            least ``factor``).
        min_pixels: Lower bound on ``h * w`` of the result.
        max_pixels: Upper bound on ``h * w`` of the result.

    Returns:
        ``(resized_height, resized_width)``.

    Raises:
        ValueError: If either dimension is not positive, or if the aspect
            ratio (long side / short side) exceeds 200.
    """
    # Reject degenerate sizes up front: a zero side would otherwise surface as
    # a bare ZeroDivisionError, and negative sides would yield nonsense sizes.
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got height={height}, width={width}")

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same ratio, rounding down so the
        # result stays within the pixel budget.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Too small: grow both sides by the same ratio, rounding up so the
        # result reaches the minimum pixel count.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
def dynamic_preprocess_native_resolution(
    image: Image.Image,
    size_factor: int = 32,
    min_pixels: int = 65536,
    max_pixels: int = 4194304,
    **_kwargs,
) -> Image.Image:
    """Resize ``image`` to the dimensions chosen by :func:`smart_resize`.

    Args:
        image: Image to resize.
        size_factor: Passed to ``smart_resize`` as ``factor``.
        min_pixels: Lower pixel-count bound passed to ``smart_resize``.
        max_pixels: Upper pixel-count bound passed to ``smart_resize``.
        **_kwargs: Accepted and ignored.

    Returns:
        The resized image.
    """
    # PIL reports and expects (width, height); smart_resize works in (height, width).
    source_width, source_height = image.size
    target_height, target_width = smart_resize(
        source_height,
        source_width,
        factor=size_factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )
    target_size = (target_width, target_height)
    return image.resize(target_size)
def preprocess_pixel_values(pixel_values: torch.Tensor, patch_size: int = 16):
    """Cut a ``(C, H, W)`` tensor into flattened ``patch_size`` x ``patch_size`` patches.

    Args:
        pixel_values: 3-D tensor unpacked as ``(channels, height, width)``.
        patch_size: Side length of each square patch.

    Returns:
        ``(patches, grid_hw)``: ``patches`` has shape
        ``(grid_h * grid_w, C * patch_size**2)`` in row-major patch order, each
        row laid out as (channel, patch row, patch column); ``grid_hw`` is a
        ``(1, 2)`` tensor ``[[grid_h, grid_w]]`` on the input's device.
    """
    channels, height, width = pixel_values.shape
    rows = height // patch_size
    cols = width // patch_size

    # Split H and W into (grid index, offset within patch).
    tiled = pixel_values.view(channels, rows, patch_size, cols, patch_size)
    # Reorder to [rows, cols, channels, patch_size, patch_size].
    patch_major = tiled.permute(1, 3, 0, 2, 4)
    values_per_patch = channels * patch_size * patch_size
    flatten_pixel_values = patch_major.reshape(rows * cols, values_per_patch)

    grid_hw = torch.tensor([[rows, cols]], device=pixel_values.device)
    return flatten_pixel_values, grid_hw
def get_contrasting_background(image: Image.Image):
    """Pick the background color used when flattening RGBA to RGB.

    The original Neo_Unify implementation derived a contrasting color from the
    alpha channel. This open-source release always answers plain white and
    ignores ``image``; callers needing the smarter behavior can override this
    function. Returning ``None`` tells the caller to use its default.
    """
    del image  # intentionally unused in this release
    white = (255,) * 3
    return white
def load_image_native(
    image,
    patch_size: int = 16,
    downsample_ratio: float = 0.5,
    min_pixels: int = 65536,
    max_pixels: int = 4194304,
    upscale: bool = False,
):
    """Load and preprocess an image: RGB convert, smart-resize, normalize, patchify.

    Args:
        image: A ``PIL.Image.Image``, or anything ``Image.open`` accepts.
        patch_size: Patch side length used when flattening into patches.
        downsample_ratio: The resize size factor is
            ``int(patch_size // downsample_ratio)``.
        min_pixels: Lower pixel-count bound for the resized image.
        max_pixels: Upper pixel-count bound for the resized image.
        upscale: If true, double width and height (bilinear) before resizing.

    Returns:
        ``(pixel_values, grid_hw)`` as produced by ``preprocess_pixel_values``.
    """

    def _to_rgb(img):
        # RGBA is composited onto a background color using its alpha channel;
        # every other mode is converted directly.
        if img.mode != "RGBA":
            return img.convert("RGB")
        bg_color = get_contrasting_background(img)
        if not bg_color:
            return img.convert("RGB")
        background = Image.new("RGB", img.size, bg_color)
        background.paste(img, mask=img.split()[3])
        return background.convert("RGB")

    if isinstance(image, Image.Image):
        image = _to_rgb(image)
    else:
        # ``Image.open`` is lazy and keeps the file open. Convert inside the
        # context manager so the pixels are read and the handle is released
        # instead of leaking until garbage collection.
        with Image.open(image) as opened:
            image = _to_rgb(opened)

    if upscale:
        image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)

    transform = T.Compose(
        [
            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
    )
    new_image = dynamic_preprocess_native_resolution(
        image,
        size_factor=int(patch_size // downsample_ratio),
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )
    pixel_values, grid_hw = preprocess_pixel_values(
        transform(new_image).to(torch.float32), patch_size=patch_size
    )
    return pixel_values, grid_hw
from .enhancer import (
DEFAULT_BACKEND,
DEFAULT_ENDPOINT,
DEFAULT_MODEL,
DEFAULT_STYLE,
PromptEnhancer,
make_adapter_from_env,
)
# Re-exported names: everything imported from ``.enhancer`` above.
__all__ = [
"DEFAULT_BACKEND",
"DEFAULT_ENDPOINT",
"DEFAULT_MODEL",
"DEFAULT_STYLE",
"PromptEnhancer",
"make_adapter_from_env",
]
"""System-prompt templates for LLM-based prompt enhancement.
Templates live as ``.md`` files under ``sensenova_u1/prompt_enhance/templates/``
(makes them easy to edit, diff and swap out without touching Python). They are
read from package data via ``importlib.resources`` each time
:func:`load_system_prompt` is called.
Adding a new style is a two-step change:
1. Drop a new ``<style>_system.md`` under ``templates/``.
2. Add ``<style>`` to :data:`AVAILABLE_STYLES` below.
"""
from __future__ import annotations
from importlib import resources
# Style names accepted by ``load_system_prompt``; each one maps to a
# ``<style>_system.md`` file in the templates package.
AVAILABLE_STYLES = ("infographic",)
"""Styles currently shipped with the package."""
# Dotted package path from which the template files are read.
_TEMPLATES_PACKAGE = "sensenova_u1.prompt_enhance.templates"
def load_system_prompt(style: str) -> str:
    """Return the text of ``<style>_system.md`` from the templates package.

    Args:
        style: One of :data:`AVAILABLE_STYLES`.

    Returns:
        The template file's contents, decoded as UTF-8.

    Raises:
        ValueError: If ``style`` is not in :data:`AVAILABLE_STYLES`.
    """
    if style in AVAILABLE_STYLES:
        template_file = resources.files(_TEMPLATES_PACKAGE).joinpath(f"{style}_system.md")
        return template_file.read_text(encoding="utf-8")
    raise ValueError(f"Unknown enhance style {style!r}; supported: {AVAILABLE_STYLES}")
from .anthropic_adapter import AnthropicVlmAdapter
from .chat_completions_adapter import ChatCompletionsVlmAdapter
from .vlm_adapter import VlmAdapter
# Re-exported names: the three adapter classes imported above.
__all__ = ["AnthropicVlmAdapter", "ChatCompletionsVlmAdapter", "VlmAdapter"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment