Commit 4a457188 authored by raojy's avatar raojy
Browse files

fix

parent a570aeea
# LLaDA2.0-Uni ComfyUI Nodes
Custom ComfyUI nodes for [LLaDA 2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni) — a unified multimodal diffusion language model supporting **text-to-image generation**, **image understanding (VQA)**, and **image editing**.
## Installation
> ⚠️ These nodes depend on the `encoder/` and `decoder/` modules in the project root. Do **not** copy `apps/comfyui` in isolation — the full repository must be present and the relative path `apps/comfyui` must be preserved.
### Option 1: Clone + symlink (recommended)
```bash
# 1. Clone the full project
git clone https://github.com/inclusionAI/LLaDA2.0-Uni.git
# 2. Symlink into ComfyUI's custom_nodes
cd /path/to/ComfyUI/custom_nodes
ln -s /path/to/LLaDA2.0-Uni/apps/comfyui ./LLaDA2Uni
```
### Option 2: One-line installer
```bash
bash /path/to/LLaDA2.0-Uni/apps/comfyui/install.sh /path/to/ComfyUI
```
### Dependencies
```bash
pip install -r apps/comfyui/requirements.txt
pip install flash-attn --no-build-isolation # optional, recommended
```
## Model Weights
In the Loader node, set the model path to either a HuggingFace repo ID or a local directory:
**HuggingFace (auto-download):**
```
inclusionAI/LLaDA2.0-Uni
```
**Local path:**
```
/path/to/LLaDA2.0-Uni
```
Expected directory layout:
```
LLaDA2.0-Uni/
├── config.json # LLM config
├── model-*.safetensors # LLM weights
├── tokenizer.json
├── decoder/
│ ├── config.json
│ └── model.safetensors # diffusion decoder
├── decoder-turbo/
│ ├── config.json
│ └── model.safetensors # turbo decoder (8-step)
├── vae/
│ └── diffusion_pytorch_model.safetensors
└── image_tokenizer/
├── config.json
├── preprocessor_config.json
├── model.safetensors # SigLIP-VQ weights
└── sigvq_embedding.pt
```
## Nodes
| Node | Description |
|------|-------------|
| **LLaDA2.0_Uni Loader** | Load the model (Flash Attention / SDPA, optional CPU offload) |
| **LLaDA2.0_Uni Text-to-Image** | Generate VQ image tokens from a text prompt (supports thinking mode) |
| **LLaDA2.0_Uni Image Understanding** | Visual question answering |
| **LLaDA2.0_Uni Image Editing** | Edit an image with a text instruction |
| **LLaDA2.0_Uni Token Decoder** | Decode VQ tokens to pixels (turbo or normal mode) |
| **LLaDA2.0_Uni Unload Model** | Manually free VRAM |
## Example Workflows
### Text-to-Image
```
Loader → Text-to-Image → Token Decoder → Preview Image
```
### Image Understanding
```
Load Image + Loader → Image Understanding → Show Text
```
### Image Editing
```
Load Image + Loader → Image Editing → Token Decoder → Preview Image
```
## Parameters
### Loader
- `model_path` — HuggingFace repo ID or local directory
- `attention``flash_attn` (recommended) or `sdpa`
- `dtype``bf16` (recommended) or `fp8`
- `offload` — enable CPU offload for limited VRAM
- `device``cuda` or `cpu`
### Text-to-Image
- `prompt` — text description
- `width` / `height` — output resolution
- `steps` — LLM denoising steps (8–32)
- `cfg_scale` — classifier-free guidance scale
- `mode``standard` or `thinking`
- `seed` — random seed (`-1` = random)
- `block_length` — block size for block-wise denoising
### Token Decoder
- `decode_mode``decoder-turbo` (fast, 8 steps) or `normal` (50 steps)
- `decoder_steps` — number of steps when using `normal` mode
- `resolution_multiplier` — upscale factor (typically `2`)
- `unload_after` — release decoder VRAM after decoding (set `False` to keep cached for faster repeated decodes)
## License
Same as the parent project. See the repository root for details.
"""
ComfyUI Custom Nodes for LLaDA2.0_Uni
Unified multimodal: Text-to-Image, Image Understanding (VQA), Image Editing
This node package lives inside LLaDA2.0_Uni/apps/comfyui/ and imports
encoder/decoder from the parent project directly.
"""
from .nodes import (
LLaDA2UniLoader,
LLaDA2UniTextToImage,
LLaDA2UniImageUnderstanding,
LLaDA2UniImageEditing,
LLaDA2UniImageDecode,
LLaDA2UniUnloadModel,
)
NODE_CLASS_MAPPINGS = {
"LLaDA2UniLoader": LLaDA2UniLoader,
"LLaDA2UniTextToImage": LLaDA2UniTextToImage,
"LLaDA2UniImageUnderstanding": LLaDA2UniImageUnderstanding,
"LLaDA2UniImageEditing": LLaDA2UniImageEditing,
"LLaDA2UniImageDecode": LLaDA2UniImageDecode,
"LLaDA2UniUnloadModel": LLaDA2UniUnloadModel,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LLaDA2UniLoader": "LLaDA2.0_Uni Loader",
"LLaDA2UniTextToImage": "LLaDA2.0_Uni Text-to-Image",
"LLaDA2UniImageUnderstanding": "LLaDA2.0_Uni Image Understanding",
"LLaDA2UniImageEditing": "LLaDA2.0_Uni Image Editing",
"LLaDA2UniImageDecode": "LLaDA2.0_Uni Token Decoder",
"LLaDA2UniUnloadModel": "LLaDA2.0_Uni Unload Model",
}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
#!/bin/bash
# LLaDA2.0_Uni ComfyUI 节点快速安装脚本
set -e
echo "=========================================="
echo "LLaDA2.0_Uni ComfyUI 节点安装"
echo "=========================================="
# 检查 ComfyUI 路径
if [ -z "$1" ]; then
echo "用法: $0 /path/to/ComfyUI"
echo ""
echo "示例:"
echo " $0 ~/ComfyUI"
exit 1
fi
COMFYUI_PATH="$1"
CUSTOM_NODES_DIR="$COMFYUI_PATH/custom_nodes"
if [ ! -d "$COMFYUI_PATH" ]; then
echo "❌ ComfyUI 路径不存在: $COMFYUI_PATH"
exit 1
fi
if [ ! -d "$CUSTOM_NODES_DIR" ]; then
echo "❌ custom_nodes 目录不存在: $CUSTOM_NODES_DIR"
exit 1
fi
# 获取当前脚本所在目录(即 apps/comfyui)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# 验证项目结构完整(需要上层的 encoder/ 和 decoder/)
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
if [ ! -d "$PROJECT_ROOT/encoder" ] || [ ! -d "$PROJECT_ROOT/decoder" ]; then
echo "❌ 项目结构不完整,缺少 encoder/ 或 decoder/ 目录"
echo " 项目根目录: $PROJECT_ROOT"
echo ""
echo "请确保通过 git clone 获取了完整项目:"
echo " git clone https://github.com/inclusionAI/LLaDA2.0-Uni.git"
exit 1
fi
echo "📂 ComfyUI 路径: $COMFYUI_PATH"
echo "📂 节点源路径: $SCRIPT_DIR"
echo "📂 项目根目录: $PROJECT_ROOT"
echo ""
# 创建软链接
TARGET_DIR="$CUSTOM_NODES_DIR/LLaDA2Uni"
if [ -e "$TARGET_DIR" ]; then
echo "⚠️ 目标已存在: $TARGET_DIR"
read -p "是否删除并重新创建?(y/n) " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
rm -rf "$TARGET_DIR"
else
echo "❌ 取消安装"
exit 1
fi
fi
ln -s "$SCRIPT_DIR" "$TARGET_DIR"
echo "✅ 创建软链接: $TARGET_DIR -> $SCRIPT_DIR"
# 安装依赖
echo ""
echo "📦 安装 Python 依赖..."
pip install -q -r "$SCRIPT_DIR/requirements.txt"
echo "✅ 依赖安装完成"
# 可选:Flash Attention
echo ""
read -p "是否安装 Flash Attention 2?(推荐,但编译较慢) (y/n) " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
echo "📦 安装 Flash Attention 2..."
pip install flash-attn --no-build-isolation || echo "⚠️ Flash Attention 安装失败,可以继续使用 SDPA"
fi
echo ""
echo "=========================================="
echo "✅ 安装完成!"
echo "=========================================="
echo ""
echo "下一步:"
echo "1. 启动 ComfyUI:"
echo " cd $COMFYUI_PATH && python main.py"
echo ""
echo "2. 在浏览器打开: http://localhost:8188"
echo ""
echo "3. 右键 → Add Node → LLaDA2.0_Uni"
echo ""
echo "4. 在 Loader 节点中设置模型路径:"
echo " inclusionAI/LLaDA2.0-Uni"
echo " (首次使用会自动从 HuggingFace 下载,也可填写本地路径)"
echo ""
echo "详细文档: $SCRIPT_DIR/README.md"
"""
Model manager for LLaDA2.0_Uni ComfyUI nodes.
Handles loading/unloading, attention backends, CPU offload, VRAM management,
and decoder model caching.
"""
import torch
import gc
import sys
import os
from typing import Dict, Any
# ── Add project root to sys.path so encoder/ and decoder/ are importable ──
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if _PROJECT_ROOT not in sys.path:
sys.path.insert(0, _PROJECT_ROOT)
# ── Global state ──
_LLM_MODEL = None
_LLM_TOKENIZER = None
_IMAGE_TOKENIZER = None
_MODEL_PATH = None
_ATTENTION = None
_DEVICE = "cuda"
_OFFLOAD = False
_DTYPE = "bf16"
# Decoder cache (module-level)
_SIGVQ_MODEL = None
_DIFF_MODEL = None
_DIFF_MODE = None
_DIFF_CONFIG = None
_VAE_MODEL = None
_DECODER_MODEL_PATH = None
def _resolve_torch_dtype(dtype: str):
if dtype == "bf16":
return torch.bfloat16
if dtype == "fp8":
print("[LLaDA2.0_Uni] FP8 mode: using bf16 compute dtype for compatibility.")
return torch.bfloat16
raise ValueError(f"Unsupported dtype: {dtype}")
# ═══════════════════════════════════════════════════════════════
# LLM Loading
# ═══════════════════════════════════════════════════════════════
def load_llm(model_path: str, device: str = "cuda", attention: str = "flash_attn",
offload: bool = False, dtype: str = "bf16"):
"""Load the dLLM-MoE backbone. Returns (model, tokenizer)."""
global _LLM_MODEL, _LLM_TOKENIZER, _MODEL_PATH, _ATTENTION, _DEVICE, _OFFLOAD, _DTYPE
_ATTENTION = attention
_DEVICE = device
_OFFLOAD = offload
_DTYPE = dtype
if _LLM_MODEL is not None and _MODEL_PATH == model_path and _DTYPE == dtype:
return _LLM_MODEL, _LLM_TOKENIZER
unload_llm()
from transformers import AutoModelForCausalLM, AutoTokenizer
attn_kwargs = {"trust_remote_code": True}
if attention == "sdpa":
attn_kwargs["attn_implementation"] = "sdpa"
if offload:
attn_kwargs["device_map"] = "auto"
attn_kwargs["max_memory"] = {0: "20GiB", "cpu": "80GiB"}
attn_kwargs["offload_folder"] = "offload_cache"
attn_kwargs["torch_dtype"] = _resolve_torch_dtype(dtype)
else:
attn_kwargs["device_map"] = device
attn_kwargs["torch_dtype"] = _resolve_torch_dtype(dtype)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, **attn_kwargs).eval()
model.tokenizer = tokenizer
_LLM_MODEL = model
_LLM_TOKENIZER = tokenizer
_MODEL_PATH = model_path
return model, tokenizer
# ═══════════════════════════════════════════════════════════════
# Image Tokenizer
# ═══════════════════════════════════════════════════════════════
def get_image_tokenizer(model_path: str, device: str = "cuda"):
"""Load the SigLIP-VQ image tokenizer."""
global _IMAGE_TOKENIZER
if _IMAGE_TOKENIZER is None:
from encoder.image_tokenizer import ImageTokenizer
_IMAGE_TOKENIZER = ImageTokenizer(model_path=model_path, device=device)
return _IMAGE_TOKENIZER
# ═══════════════════════════════════════════════════════════════
# Decoder (with caching)
# ═══════════════════════════════════════════════════════════════
def decode_tokens(token_ids, h, w, model_path: str, device: str = "cuda",
num_steps: int = 50, decode_mode: str = "normal",
resolution_multiplier: int = 2, progress_callback=None):
"""Decode VQ tokens to PIL image, with model caching."""
global _SIGVQ_MODEL, _DIFF_MODEL, _DIFF_MODE, _DIFF_CONFIG, _VAE_MODEL, _DECODER_MODEL_PATH
import json
import torch.nn.functional as F
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image
from diffusers import AutoencoderKL
from safetensors.torch import load_file
from decoder.sigvq import SigVQ
from decoder.decoder_model import ZImageTransformer2DModel
from decoder.transport import create_transport, Sampler
dtype = torch.bfloat16
# ── Stage 1: SigVQ → semantic features (cached) ──
sigvq_path = os.path.join(model_path, "image_tokenizer", "sigvq_embedding.pt")
if _SIGVQ_MODEL is None or _DECODER_MODEL_PATH != model_path:
extractor = SigVQ(vocab_size=16384, inner_dim=4096).to(device, dtype=dtype)
extractor.load_state_dict(torch.load(sigvq_path, map_location=device, weights_only=True))
extractor.eval()
_SIGVQ_MODEL = extractor
_DECODER_MODEL_PATH = model_path
print("[LLaDA2.0_Uni Decoder] SigVQ loaded and cached.")
th = h * 16 * resolution_multiplier
tw = w * 16 * resolution_multiplier
tok = torch.tensor(token_ids).view(1, 1, h, w).float().to(device)
up = F.interpolate(tok, scale_factor=2, mode="nearest").long().view(1, -1)
cap_pos = [_SIGVQ_MODEL(up).squeeze(0)]
cap_neg = [torch.zeros_like(cap_pos[0])]
# ── Stage 2: Diffusion ODE sampling (cached) ──
if decode_mode == "decoder-turbo":
decoder_dir = os.path.join(model_path, "decoder-turbo")
else:
decoder_dir = os.path.join(model_path, "decoder")
if _DIFF_MODEL is None or _DIFF_MODE != decode_mode or _DECODER_MODEL_PATH != model_path:
# Free old model if mode changed
if _DIFF_MODEL is not None:
del _DIFF_MODEL
gc.collect()
torch.cuda.empty_cache()
config_path = os.path.join(decoder_dir, "config.json")
with open(config_path) as f:
cfg = json.load(f)
cfg["axes_lens"] = [32768, 1024, 1024]
cfg["cap_feat_dim"] = 4096
with torch.device("meta"):
diff_model = ZImageTransformer2DModel(**cfg)
ckpt = os.path.join(decoder_dir, "model.safetensors")
diff_model.load_state_dict(load_file(ckpt, device=str(device)), assign=True)
diff_model = diff_model.to(dtype=dtype).eval()
_DIFF_MODEL = diff_model
_DIFF_MODE = decode_mode
_DIFF_CONFIG = cfg
print(f"[LLaDA2.0_Uni Decoder] Diffusion model ({decode_mode}) loaded and cached.")
cfg = _DIFF_CONFIG
# Create model function for sampling
n = len(cap_pos)
doubled = cap_pos + cap_neg
cfg_scale = 0.0 if decode_mode == "decoder-turbo" else 1.0
patch_size = cfg.get("all_patch_size", (2,))[0]
f_patch_size = cfg.get("all_f_patch_size", (1,))[0]
def model_fn(x, t, **kw):
t_t = torch.tensor([t], device=x.device, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t.float()
if t_t.dim() == 0: t_t = t_t.unsqueeze(0)
if t_t.shape[0] == 1 and x.shape[0] > 1: t_t = t_t.expand(x.shape[0])
if cfg_scale > 0:
out = _DIFF_MODEL(x=list(x.to(dtype).repeat(2, 1, 1, 1, 1).unbind(0)), t=t_t.repeat(2),
cap_feats=doubled, patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
pos, neg = out[0][:n], out[0][n:]
res = []
for p, ng in zip(pos, neg):
p, ng = p.float(), ng.float()
pred = p + cfg_scale * (p - ng)
on, nn_ = torch.linalg.vector_norm(p), torch.linalg.vector_norm(pred)
if nn_ > on:
pred *= on / nn_
res.append(pred)
return torch.stack(res)
out = _DIFF_MODEL(x=list(x.to(dtype).unbind(0)), t=t_t, cap_feats=cap_pos,
patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
return torch.stack([o.float() for o in out[0]])
z = torch.randn([1, 16, 1, 2 * (th // 16), 2 * (tw // 16)], device=device)
sampler = Sampler(create_transport("Linear", "velocity", None))
sample_fn = sampler.sample_ode(
sampling_method="euler", num_steps=num_steps,
atol=1e-6, rtol=1e-3, reverse=False, time_shifting_factor=6,
stochast_ratio=1.0 if decode_mode == "decoder-turbo" else 0.0)
step_counter = [0]
if progress_callback is not None:
def wrapped(x, t, **kw):
step_counter[0] += 1
progress_callback(step_counter[0], num_steps)
return model_fn(x, t, **kw)
else:
pbar = tqdm(total=num_steps, desc="Decoding", leave=False)
def wrapped(x, t, **kw):
pbar.update(1)
return model_fn(x, t, **kw)
with torch.inference_mode():
samples = sample_fn(z, wrapped)[-1].squeeze(2)
if progress_callback is None:
pbar.close()
# ── Stage 3: VAE decode (cached) ──
vae_dir = os.path.join(model_path, "vae")
if _VAE_MODEL is None or _DECODER_MODEL_PATH != model_path:
_VAE_MODEL = AutoencoderKL.from_pretrained(vae_dir, torch_dtype=dtype).to(device).eval()
print("[LLaDA2.0_Uni Decoder] VAE loaded and cached.")
with torch.inference_mode():
s = samples.to(dtype)
s = (s / _VAE_MODEL.config.scaling_factor) + _VAE_MODEL.config.shift_factor
px = ((_VAE_MODEL.decode(s, return_dict=False)[0] + 1) / 2).clamp_(0, 1)
return to_pil_image(px[0].float())
# ═══════════════════════════════════════════════════════════════
# Unload functions
# ═══════════════════════════════════════════════════════════════
def unload_llm():
"""Unload LLM backbone to free VRAM."""
global _LLM_MODEL, _LLM_TOKENIZER
if _LLM_MODEL is not None:
del _LLM_MODEL, _LLM_TOKENIZER
_LLM_MODEL = None
_LLM_TOKENIZER = None
gc.collect()
torch.cuda.empty_cache()
def unload_decoder():
"""Unload all decoder components from VRAM."""
global _SIGVQ_MODEL, _DIFF_MODEL, _DIFF_MODE, _DIFF_CONFIG, _VAE_MODEL, _DECODER_MODEL_PATH
for obj in (_SIGVQ_MODEL, _DIFF_MODEL, _VAE_MODEL):
if obj is not None:
del obj
_SIGVQ_MODEL = None
_DIFF_MODEL = None
_DIFF_MODE = None
_DIFF_CONFIG = None
_VAE_MODEL = None
_DECODER_MODEL_PATH = None
gc.collect()
torch.cuda.empty_cache()
def unload_image_tokenizer():
"""Unload image tokenizer."""
global _IMAGE_TOKENIZER
if _IMAGE_TOKENIZER is not None:
del _IMAGE_TOKENIZER
_IMAGE_TOKENIZER = None
gc.collect()
torch.cuda.empty_cache()
def unload_all():
"""Unload everything. Call this to free all VRAM."""
unload_llm()
unload_decoder()
unload_image_tokenizer()
"""
ComfyUI Custom Nodes for LLaDA2.0_Uni
Supports: Text-to-Image, Image Understanding, Image Editing, Decode, Unload
Features:
- Attention backend selection: Flash Attention, SDPA
- CPU offload mode for limited VRAM
- Auto VRAM cleanup after inference
- Decoder model caching (60-67% speedup on repeated decodes)
- Manual unload node
"""
import torch
import numpy as np
from PIL import Image
from typing import Optional
import os
import shutil
import subprocess
# ═══════════════════════════════════════════════════════════════
# LLaDA2.0_Uni Loader
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniLoader:
"""Load LLaDA2.0_Uni model with attention & offload options."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model_path": ("STRING", {"default": "inclusionAI/LLaDA2.0-Uni"}),
"attention": (["flash_attn", "sdpa"],),
"dtype": (["bf16", "fp8"],),
"offload": ("BOOLEAN", {"default": False}),
"device": (["cuda", "cpu"],),
}
}
RETURN_TYPES = ("LLADA2UNI_MODEL",)
RETURN_NAMES = ("model",)
FUNCTION = "load"
CATEGORY = "LLaDA2.0_Uni"
DESCRIPTION = "Load LLaDA2.0_Uni model from specified path."
def load(self, model_path, attention, dtype, offload, device):
if not isinstance(device, str) or not device:
device = "cuda"
# Validate model path
if not os.path.isdir(model_path):
raise ValueError(f"Model path does not exist: {model_path}")
from .model_manager import load_llm
model, tokenizer = load_llm(model_path, device=device,
attention=attention, offload=offload, dtype=dtype)
return ({
"model_path": model_path,
"device": device,
"attention": attention,
"dtype": dtype,
"offload": offload,
},)
@classmethod
def IS_CHANGED(cls, **kwargs):
return float("NaN")
# ═══════════════════════════════════════════════════════════════
# Text-to-Image
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniTextToImage:
"""Generate image tokens from text prompt."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("LLADA2UNI_MODEL",),
"prompt": ("STRING", {"multiline": True, "default": ""}),
"width": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 64}),
"height": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 64}),
"steps": ("INT", {"default": 8, "min": 1, "max": 32}),
"cfg_scale": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
},
"optional": {
"mode": (["standard", "thinking"],),
"thinking_steps": ("INT", {"default": 32, "min": 1, "max": 64}),
"thinking_length": ("INT", {"default": 4096, "min": 512, "max": 8192, "step": 512}),
"seed": ("INT", {"default": -1, "min": -1, "max": 2**32 - 1}),
"block_length": ("INT", {"default": 32, "min": 1, "max": 128}),
"unload_after": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("LLADA2UNI_TOKENS", "STRING",)
RETURN_NAMES = ("tokens", "thinking_output",)
FUNCTION = "generate"
CATEGORY = "LLaDA2.0_Uni"
DESCRIPTION = "Generate VQ image tokens from text prompt."
def generate(self, model, prompt, width, height, steps, cfg_scale,
mode="standard", thinking_steps=32, thinking_length=4096,
seed=-1, block_length=32, unload_after=True):
# Set seed for reproducibility
if seed >= 0:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
from .model_manager import load_llm, unload_llm
llm, _ = load_llm(model["model_path"], model["device"],
model.get("attention", "flash_attn"),
model.get("offload", False),
model.get("dtype", "bf16"))
# Generate image tokens
if mode == "thinking":
result = llm.generate_image(
prompt, image_h=height, image_w=width,
mode="thinking",
steps=steps, cfg_scale=cfg_scale,
thinking_steps=thinking_steps,
thinking_gen_length=thinking_length,
block_length=block_length,
)
thinking_text = result.get("thinking", "")
else:
result = llm.generate_image(
prompt, image_h=height, image_w=width,
steps=steps, cfg_scale=cfg_scale,
block_length=block_length,
)
thinking_text = ""
token_data = {
"token_ids": result["token_ids"],
"h": result["h"],
"w": result["w"],
"model_path": model["model_path"],
}
# Auto VRAM cleanup
if unload_after:
unload_llm()
return (token_data, thinking_text,)
# ═══════════════════════════════════════════════════════════════
# Image Understanding (VQA)
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniImageUnderstanding:
"""Understand/describe an image using VQA."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("LLADA2UNI_MODEL",),
"image": ("IMAGE",),
"question": ("STRING", {"multiline": True, "default": "Describe this image in detail."}),
},
"optional": {
"gen_steps": ("INT", {"default": 32, "min": 1, "max": 64}),
"gen_length": ("INT", {"default": 2048, "min": 256, "max": 8192, "step": 256}),
"block_length": ("INT", {"default": 32, "min": 1, "max": 128}),
"unload_after": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("response",)
FUNCTION = "understand"
CATEGORY = "LLaDA2.0_Uni"
DESCRIPTION = "Ask questions about an image (VQA)."
def understand(self, model, image, question, gen_steps=32, gen_length=2048,
block_length=32, unload_after=True):
from .model_manager import load_llm, get_image_tokenizer, unload_llm
from decoder.smart_img_process import smart_resize_images
llm, _ = load_llm(model["model_path"], model["device"],
model.get("attention", "flash_attn"),
model.get("offload", False),
model.get("dtype", "bf16"))
# ComfyUI tensor (B,H,W,C) → PIL
img_np = (image[0].cpu().numpy() * 255).astype(np.uint8)
pil_image = Image.fromarray(img_np)
# Encode image
image_tokenizer = get_image_tokenizer(model["model_path"], model["device"])
pil_image = smart_resize_images([pil_image])[0]
info = image_tokenizer.encode_with_info(pil_image)
image_tokens = [x + llm.config.image_token_offset for x in info["token_ids"]]
_, h, w = info["grid_thw"]
# Understand the image
response = llm.understand_image(
image_tokens, h, w,
question=question,
steps=gen_steps,
gen_length=gen_length,
block_length=block_length,
)
# Auto VRAM cleanup
if unload_after:
unload_llm()
return (response,)
# ═══════════════════════════════════════════════════════════════
# Image Editing
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniImageEditing:
"""Edit an image based on a text instruction."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("LLADA2UNI_MODEL",),
"image": ("IMAGE",),
"instruction": ("STRING", {"multiline": True, "default": ""}),
},
"optional": {
"steps": ("INT", {"default": 8, "min": 1, "max": 32}),
"cfg_text_scale": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"cfg_image_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"block_length": ("INT", {"default": 32, "min": 1, "max": 128}),
"seed": ("INT", {"default": -1, "min": -1, "max": 2**32 - 1}),
"unload_after": ("BOOLEAN", {"default": True}),
}
}
RETURN_TYPES = ("LLADA2UNI_TOKENS",)
RETURN_NAMES = ("tokens",)
FUNCTION = "edit"
CATEGORY = "LLaDA2.0_Uni"
DESCRIPTION = "Edit an image with a text instruction. Connect output to Decode node."
def edit(self, model, image, instruction, steps=8, cfg_text_scale=4.0,
cfg_image_scale=0.0, block_length=32, seed=-1, unload_after=True):
# Set seed for reproducibility
if seed >= 0:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
from .model_manager import load_llm, get_image_tokenizer, unload_llm
from decoder.utils import generate_crop_size_list, var_center_crop
llm, _ = load_llm(model["model_path"], model["device"],
model.get("attention", "flash_attn"),
model.get("offload", False),
model.get("dtype", "bf16"))
# ComfyUI tensor → PIL
img_np = (image[0].cpu().numpy() * 255).astype(np.uint8)
pil_image = Image.fromarray(img_np)
# Encode source image (preserve aspect ratio via var_center_crop)
image_tokenizer = get_image_tokenizer(model["model_path"], model["device"])
crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32)
pil_image = var_center_crop(pil_image, crop_size_list=crop_size_list)
info = image_tokenizer.encode_with_info(pil_image)
image_tokens = [x + llm.config.image_token_offset for x in info["token_ids"]]
_, h, w = info["grid_thw"]
# Edit the image
result = llm.edit_image(
image_tokens, h, w, instruction,
steps=steps,
block_length=block_length,
cfg_text_scale=cfg_text_scale,
cfg_image_scale=cfg_image_scale,
)
token_data = {
"token_ids": result["token_ids"],
"h": result["h"],
"w": result["w"],
"model_path": model["model_path"],
}
# Auto VRAM cleanup
if unload_after:
unload_llm()
return (token_data,)
# ═══════════════════════════════════════════════════════════════
# Token Decoder
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniImageDecode:
"""Decode VQ tokens from LLaDA2.0_Uni into a pixel image."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"tokens": ("LLADA2UNI_TOKENS",),
},
"optional": {
"decode_mode": (["decoder-turbo", "normal"],),
"decoder_steps": ("INT", {"default": 50, "min": 1, "max": 100}),
"resolution_multiplier": ("INT", {"default": 2, "min": 1, "max": 4}),
"unload_after": ("BOOLEAN", {"default": False}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "decode"
CATEGORY = "LLaDA2.0_Uni"
DESCRIPTION = "Decode VQ tokens to pixel image. decoder-turbo is ~10× faster (8 steps vs 50)."
def decode(self, tokens, decode_mode="decoder-turbo", decoder_steps=50,
resolution_multiplier=2, unload_after=False):
from .model_manager import decode_tokens, unload_all
num_steps = 8 if decode_mode == "decoder-turbo" else decoder_steps
# ComfyUI progress bar integration
progress_callback = None
try:
from comfy.utils import ProgressBar
pbar = ProgressBar(num_steps)
def progress_callback(step, total):
pbar.update(1)
except ImportError:
pass
pil_image = decode_tokens(
tokens["token_ids"],
tokens["h"],
tokens["w"],
tokens["model_path"],
device="cuda",
num_steps=num_steps,
decode_mode=decode_mode,
resolution_multiplier=resolution_multiplier,
progress_callback=progress_callback,
)
# PIL → ComfyUI tensor (1, H, W, 3) float32 [0,1]
img_np = np.array(pil_image).astype(np.float32) / 255.0
tensor = torch.from_numpy(img_np).unsqueeze(0)
# Auto VRAM cleanup after full pipeline
if unload_after:
unload_all()
return (tensor,)
# ═══════════════════════════════════════════════════════════════
# Manual Unload
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniUnloadModel:
"""Manually unload all LLaDA2.0_Uni components and free VRAM."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {},
"optional": {},
}
RETURN_TYPES = ()
OUTPUT_NODE = True
FUNCTION = "unload"
CATEGORY = "LLaDA2.0_Uni"
DESCRIPTION = "Unload all LLaDA2.0_Uni models from VRAM."
def unload(self):
from .model_manager import unload_all
unload_all()
print("[LLaDA2.0_Uni] ✅ All models unloaded, VRAM freed")
return ()
torch>=2.4.0
torchvision>=0.19.0
transformers>=4.45.0
safetensors>=0.4.0
diffusers>=0.30.0
Pillow>=10.0.0
numpy>=1.24.0
tqdm>=4.65.0
accelerate>=0.26.0
torchdiffeq
from .sigvq import SigVQ
from .decoder_model import ZImageTransformer2DModel
from .decode import decode_vq_tokens
from .utils import generate_crop_size_list, var_center_crop
from . import transport
"""Decode VQ token IDs into a PIL Image via SigVQ + ZImage diffusion + VAE."""
import os
import json
import gc
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image
from diffusers import AutoencoderKL
from safetensors.torch import load_file
from .sigvq import SigVQ
from .decoder_model import ZImageTransformer2DModel
from .transport import create_transport, Sampler
def _create_decoder_model_fn(model, cap_pos, cap_neg, cfg_scale, patch_size, f_patch_size, dtype):
n = len(cap_pos)
doubled = cap_pos + cap_neg
def fn(x, t, **kw):
t_t = torch.tensor([t], device=x.device, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t.float()
if t_t.dim() == 0: t_t = t_t.unsqueeze(0)
if t_t.shape[0] == 1 and x.shape[0] > 1: t_t = t_t.expand(x.shape[0])
if cfg_scale > 0:
out = model(x=list(x.to(dtype).repeat(2, 1, 1, 1, 1).unbind(0)), t=t_t.repeat(2),
cap_feats=doubled, patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
pos, neg = out[0][:n], out[0][n:]
res = []
for p, ng in zip(pos, neg):
p, ng = p.float(), ng.float()
pred = p + cfg_scale * (p - ng)
on, nn_ = torch.linalg.vector_norm(p), torch.linalg.vector_norm(pred)
if nn_ > on:
pred *= on / nn_
res.append(pred)
return torch.stack(res)
out = model(x=list(x.to(dtype).unbind(0)), t=t_t, cap_feats=cap_pos,
patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
return torch.stack([o.float() for o in out[0]])
return fn
@torch.inference_mode()
def decode_vq_tokens(token_ids, h, w, model_path, device,
resolution_multiplier=2, num_steps=50,
decode_mode="normal"):
"""
Decode VQ token IDs into a PIL Image.
Args:
token_ids: List of VQ token IDs (without the +157184 offset).
h, w: Semantic grid size (image_pixels // 16).
model_path: Root path of the model directory.
device: torch device.
resolution_multiplier: Upscale factor (2 = 1024px from 512px tokens).
num_steps: ODE sampling steps.
decode_mode: ``"normal"`` uses the standard decoder (default, 50 steps);
``"decoder-turbo"`` uses the distilled decoder (faster, ~8 steps).
Returns:
PIL.Image
"""
dtype = torch.bfloat16
sigvq_path = os.path.join(model_path, "image_tokenizer", "sigvq_embedding.pt")
if decode_mode == "decoder-turbo":
decoder_dir = os.path.join(model_path, "decoder-turbo")
else:
decoder_dir = os.path.join(model_path, "decoder")
vae_dir = os.path.join(model_path, "vae")
# ---------- Stage 1: SigVQ → semantic features ----------
extractor = SigVQ(vocab_size=16384, inner_dim=4096).to(device, dtype=dtype)
extractor.load_state_dict(
torch.load(sigvq_path, map_location=device, weights_only=True))
extractor.eval()
th = h * 16 * resolution_multiplier
tw = w * 16 * resolution_multiplier
tok = torch.tensor(token_ids).view(1, 1, h, w).float().to(device)
up = F.interpolate(tok, scale_factor=2, mode="nearest").long().view(1, -1)
cap_pos = [extractor(up).squeeze(0)]
cap_neg = [torch.zeros_like(cap_pos[0])]
# SigVQ is no longer needed — release immediately
del extractor
gc.collect()
torch.cuda.empty_cache()
# ---------- Stage 2: Diffusion ODE sampling ----------
config_path = os.path.join(decoder_dir, "config.json")
with open(config_path) as f:
cfg = json.load(f)
cfg["axes_lens"] = [32768, 1024, 1024]
cfg["cap_feat_dim"] = 4096
# Build model on meta device, load weights directly to GPU, then tie —
# avoids the ~12 GB peak from holding both random init + loaded weights.
with torch.device("meta"):
diff_model = ZImageTransformer2DModel(**cfg)
ckpt = os.path.join(decoder_dir, "model.safetensors")
diff_model.load_state_dict(load_file(ckpt, device=str(device)), assign=True)
diff_model = diff_model.to(dtype=dtype).eval()
z = torch.randn([1, 16, 1, 2 * (th // 16), 2 * (tw // 16)], device=device)
model_fn = _create_decoder_model_fn(
diff_model, cap_pos, cap_neg,
cfg_scale=0.0 if decode_mode == "decoder-turbo" else 1.0,
patch_size=cfg.get("all_patch_size", (2,))[0],
f_patch_size=cfg.get("all_f_patch_size", (1,))[0],
dtype=dtype)
sampler = Sampler(create_transport("Linear", "velocity", None))
sample_fn = sampler.sample_ode(
sampling_method="euler", num_steps=num_steps,
atol=1e-6, rtol=1e-3, reverse=False, time_shifting_factor=6,
stochast_ratio=1.0 if decode_mode == "decoder-turbo" else 0.0)
pbar = tqdm(total=num_steps, desc="Decoding", leave=False)
def wrapped(x, t, **kw):
pbar.update(1)
return model_fn(x, t, **kw)
samples = sample_fn(z, wrapped)[-1].squeeze(2)
pbar.close()
# Diffusion model is done — release before loading VAE
del diff_model, cap_pos, cap_neg, model_fn
gc.collect()
torch.cuda.empty_cache()
# ---------- Stage 3: VAE decode ----------
vae = AutoencoderKL.from_pretrained(vae_dir, torch_dtype=dtype).to(device).eval()
s = samples.to(dtype)
s = (s / vae.config.scaling_factor) + vae.config.shift_factor
px = ((vae.decode(s, return_dict=False)[0] + 1) / 2).clamp_(0, 1)
del vae
gc.collect()
torch.cuda.empty_cache()
return to_pil_image(px[0].float())
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import RMSNorm
from diffusers.utils.torch_utils import maybe_allow_in_graph
try:
from diffusers.models.attention_dispatch import dispatch_attention_fn
_HAS_DISPATCH_ATTENTION = True
except ImportError:
_HAS_DISPATCH_ATTENTION = False
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from flash_attn import flash_attn_func
ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
X_PAD_DIM = 64
class TimestepEmbedder(nn.Module):
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
super().__init__()
if mid_size is None:
mid_size = out_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, mid_size, bias=True),
nn.SiLU(),
nn.Linear(mid_size, out_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast("cuda", enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
weight_dtype = self.mlp[0].weight.dtype
compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
if weight_dtype.is_floating_point:
t_freq = t_freq.to(weight_dtype)
elif compute_dtype is not None:
t_freq = t_freq.to(compute_dtype)
t_emb = self.mlp(t_freq)
return t_emb
class ZSingleStreamAttnProcessor:
"""
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
original Z-ImageAttention module.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
# Apply Norms
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)
# From [batch, seq_len] to appropriate mask format
if attention_mask is not None and attention_mask.ndim == 2:
if _HAS_DISPATCH_ATTENTION:
# dispatch_attention_fn expects 4D mask: [batch, 1, 1, seq_len]
attention_mask = attention_mask[:, None, None, :]
else:
# flash_attn: mask out inputs directly
mask_expanded = attention_mask.unsqueeze(-1).unsqueeze(-1) # (B, S, 1, 1)
query = query * mask_expanded
key = key * mask_expanded
value = value * mask_expanded
# Compute joint attention
if _HAS_DISPATCH_ATTENTION:
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
else:
hidden_states = flash_attn_func(
query, key, value,
dropout_p=0.0,
causal=False,
)
# Reshape back
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(dtype)
output = attn.to_out[0](hidden_states)
if len(attn.to_out) > 1: # dropout
output = attn.to_out[1](output)
return output
def select_per_token(
value_noisy: torch.Tensor,
value_clean: torch.Tensor,
noise_mask: torch.Tensor,
seq_len: int,
) -> torch.Tensor:
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
return torch.where(
noise_mask_expanded == 1,
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
qk_norm: bool,
modulation=True,
):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
# Refactored to use diffusers Attention with custom processor
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
self.attention = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // n_heads,
heads=n_heads,
qk_norm="rms_norm" if qk_norm else None,
eps=1e-5,
bias=False,
out_bias=False,
processor=ZSingleStreamAttnProcessor(),
)
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
noise_mask: Optional[torch.Tensor] = None,
adaln_noisy: Optional[torch.Tensor] = None,
adaln_clean: Optional[torch.Tensor] = None,
):
if self.modulation:
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation: different modulation for noisy/clean tokens
mod_noisy = self.adaLN_modulation(adaln_noisy)
mod_clean = self.adaLN_modulation(adaln_clean)
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
else:
# Global modulation: same modulation for all tokens (avoid double select)
mod = self.adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
)
x = x + gate_msa * self.attention_norm2(attn_out)
# FFN block
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
else:
# Attention block
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
x = x + self.attention_norm2(attn_out)
# FFN block
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
return x
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
)
def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation
scale_noisy = 1.0 + self.adaLN_modulation(c_noisy)
scale_clean = 1.0 + self.adaLN_modulation(c_clean)
scale = select_per_token(scale_noisy, scale_clean, noise_mask, seq_len)
else:
# Original global modulation
assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
scale = 1.0 + self.adaLN_modulation(c)
scale = scale.unsqueeze(1)
x = self.norm_final(x) * scale
x = self.linear(x)
return x
class RopeEmbedder:
def __init__(
self,
theta: float = 256.0,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (64, 128, 128),
):
self.theta = theta
self.axes_dims = axes_dims
self.axes_lens = axes_lens
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
self.freqs_cis = None
@staticmethod
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
with torch.device("cpu"):
freqs_cis = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
freqs_cis.append(freqs_cis_i)
return freqs_cis
def __call__(self, ids: torch.Tensor):
assert ids.ndim == 2
assert ids.shape[-1] == len(self.axes_dims)
device = ids.device
if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
else:
# Ensure freqs_cis are on the same device as ids
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)
class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["ZImageTransformerBlock"]
_repeated_blocks = ["ZImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers
@register_to_config
def __init__(
self,
all_patch_size=(2,),
all_f_patch_size=(1,),
in_channels=16,
dim=3840,
n_layers=30,
n_refiner_layers=2,
n_heads=30,
n_kv_heads=30,
norm_eps=1e-5,
qk_norm=True,
cap_feat_dim=2560,
siglip_feat_dim=None, # Optional: set to enable SigLIP support for Omni
rope_theta=256.0,
t_scale=1000.0,
axes_dims=[32, 48, 48],
axes_lens=[1024, 512, 512],
) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
self.all_patch_size = all_patch_size
self.all_f_patch_size = all_f_patch_size
self.dim = dim
self.n_heads = n_heads
self.rope_theta = rope_theta
self.t_scale = t_scale
self.gradient_checkpointing = False
assert len(all_patch_size) == len(all_f_patch_size)
all_x_embedder = {}
all_final_layer = {}
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
self.all_x_embedder = nn.ModuleDict(all_x_embedder)
self.all_final_layer = nn.ModuleDict(all_final_layer)
self.noise_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
1000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=True,
)
for layer_id in range(n_refiner_layers)
]
)
self.context_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
# self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
self.semantic_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
# Optional SigLIP components (for Omni variant)
if siglip_feat_dim is not None:
self.siglip_embedder = nn.Sequential(
RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
)
self.siglip_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
2000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=False,
)
for layer_id in range(n_refiner_layers)
]
)
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
else:
self.siglip_embedder = None
self.siglip_refiner = None
self.siglip_pad_token = None
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
self.layers = nn.ModuleList(
[
ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
for layer_id in range(n_layers)
]
)
head_dim = dim // n_heads
assert head_dim == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
def unpatchify(
self,
x: List[torch.Tensor],
size: List[Tuple],
patch_size,
f_patch_size,
x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
) -> List[torch.Tensor]:
pH = pW = patch_size
pF = f_patch_size
bsz = len(x)
assert len(size) == bsz
if x_pos_offsets is not None:
# Omni: extract target image from unified sequence (cond_images + target)
result = []
for i in range(bsz):
unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
cu_len = 0
x_item = None
for j in range(len(size[i])):
if size[i][j] is None:
ori_len = 0
pad_len = SEQ_MULTI_OF
cu_len += pad_len + ori_len
else:
F, H, W = size[i][j]
ori_len = (F // pF) * (H // pH) * (W // pW)
pad_len = (-ori_len) % SEQ_MULTI_OF
x_item = (
unified_x[cu_len : cu_len + ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
cu_len += ori_len + pad_len
result.append(x_item) # Return only the last (target) image
return result
else:
# Original mode: simple unpatchify
for i in range(bsz):
F, H, W = size[i]
ori_len = (F // pF) * (H // pH) * (W // pW)
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
x[i] = (
x[i][:ori_len]
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
.permute(6, 0, 3, 1, 4, 2, 5)
.reshape(self.out_channels, F, H, W)
)
return x
@staticmethod
def create_coordinate_grid(size, start=None, device=None):
if start is None:
start = (0 for _ in size)
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
pH, pW, pF = patch_size, patch_size, f_patch_size
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
def _pad_with_ids(
self,
feat: torch.Tensor,
pos_grid_size: Tuple,
pos_start: Tuple,
device: torch.device,
noise_mask_val: Optional[int] = None,
):
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
ori_len = len(feat)
pad_len = (-ori_len) % SEQ_MULTI_OF
total_len = ori_len + pad_len
# Pos IDs
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
if pad_len > 0:
pad_pos_ids = (
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(pad_len, 1)
)
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
pad_mask = torch.cat(
[
torch.zeros(ori_len, dtype=torch.bool, device=device),
torch.ones(pad_len, dtype=torch.bool, device=device),
]
)
else:
pos_ids = ori_pos_ids
padded_feat = feat
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
def patchify_and_embed(
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
):
"""Patchify for basic mode: single image per batch item."""
device = all_image[0].device
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
for image, cap_feat in zip(all_image, all_cap_feats):
# Caption
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
)
all_cap_out.append(cap_out)
all_cap_pos_ids.append(cap_pos_ids)
all_cap_pad_mask.append(cap_pad_mask)
# Image
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
)
all_img_out.append(img_out)
all_img_size.append(size)
all_img_pos_ids.append(img_pos_ids)
all_img_pad_mask.append(img_pad_mask)
return (
all_img_out,
all_cap_out,
all_img_size,
all_img_pos_ids,
all_cap_pos_ids,
all_img_pad_mask,
all_cap_pad_mask,
)
def patchify_and_embed_omni(
self,
all_x: List[List[torch.Tensor]],
all_cap_feats: List[List[torch.Tensor]],
all_siglip_feats: List[List[torch.Tensor]],
patch_size: int,
f_patch_size: int,
images_noise_mask: List[List[int]],
):
"""Patchify for omni mode: multiple images per batch item with noise masks."""
bsz = len(all_x)
device = all_x[0][-1].device
dtype = all_x[0][-1].dtype
all_x_out, all_x_size, all_x_pos_ids, all_x_pad_mask, all_x_len, all_x_noise_mask = [], [], [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask, all_cap_len, all_cap_noise_mask = [], [], [], [], []
all_sig_out, all_sig_pos_ids, all_sig_pad_mask, all_sig_len, all_sig_noise_mask = [], [], [], [], []
for i in range(bsz):
num_images = len(all_x[i])
cap_feats_list, cap_pos_list, cap_mask_list, cap_lens, cap_noise = [], [], [], [], []
cap_end_pos = []
cap_cu_len = 1
# Process captions
for j, cap_item in enumerate(all_cap_feats[i]):
noise_val = images_noise_mask[i][j] if j < len(images_noise_mask[i]) else 1
cap_out, cap_pos, cap_mask, cap_len, cap_nm = self._pad_with_ids(
cap_item,
(len(cap_item) + (-len(cap_item)) % SEQ_MULTI_OF, 1, 1),
(cap_cu_len, 0, 0),
device,
noise_val,
)
cap_feats_list.append(cap_out)
cap_pos_list.append(cap_pos)
cap_mask_list.append(cap_mask)
cap_lens.append(cap_len)
cap_noise.extend(cap_nm)
cap_cu_len += len(cap_item)
cap_end_pos.append(cap_cu_len)
cap_cu_len += 2 # for image vae and siglip tokens
all_cap_out.append(torch.cat(cap_feats_list, dim=0))
all_cap_pos_ids.append(torch.cat(cap_pos_list, dim=0))
all_cap_pad_mask.append(torch.cat(cap_mask_list, dim=0))
all_cap_len.append(cap_lens)
all_cap_noise_mask.append(cap_noise)
# Process images
x_feats_list, x_pos_list, x_mask_list, x_lens, x_size, x_noise = [], [], [], [], [], []
for j, x_item in enumerate(all_x[i]):
noise_val = images_noise_mask[i][j]
if x_item is not None:
x_patches, size, (F_t, H_t, W_t) = self._patchify_image(x_item, patch_size, f_patch_size)
x_out, x_pos, x_mask, x_len, x_nm = self._pad_with_ids(
x_patches, (F_t, H_t, W_t), (cap_end_pos[j], 0, 0), device, noise_val
)
x_size.append(size)
else:
x_len = SEQ_MULTI_OF
x_out = torch.zeros((x_len, X_PAD_DIM), dtype=dtype, device=device)
x_pos = self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(x_len, 1)
x_mask = torch.ones(x_len, dtype=torch.bool, device=device)
x_nm = [noise_val] * x_len
x_size.append(None)
x_feats_list.append(x_out)
x_pos_list.append(x_pos)
x_mask_list.append(x_mask)
x_lens.append(x_len)
x_noise.extend(x_nm)
all_x_out.append(torch.cat(x_feats_list, dim=0))
all_x_pos_ids.append(torch.cat(x_pos_list, dim=0))
all_x_pad_mask.append(torch.cat(x_mask_list, dim=0))
all_x_size.append(x_size)
all_x_len.append(x_lens)
all_x_noise_mask.append(x_noise)
# Process siglip
if all_siglip_feats[i] is None:
all_sig_len.append([0] * num_images)
all_sig_out.append(None)
else:
sig_feats_list, sig_pos_list, sig_mask_list, sig_lens, sig_noise = [], [], [], [], []
for j, sig_item in enumerate(all_siglip_feats[i]):
noise_val = images_noise_mask[i][j]
if sig_item is not None:
sig_H, sig_W, sig_C = sig_item.size()
sig_flat = sig_item.permute(2, 0, 1).reshape(sig_H * sig_W, sig_C)
sig_out, sig_pos, sig_mask, sig_len, sig_nm = self._pad_with_ids(
sig_flat, (1, sig_H, sig_W), (cap_end_pos[j] + 1, 0, 0), device, noise_val
)
# Scale position IDs to match x resolution
if x_size[j] is not None:
sig_pos = sig_pos.float()
sig_pos[..., 1] = sig_pos[..., 1] / max(sig_H - 1, 1) * (x_size[j][1] - 1)
sig_pos[..., 2] = sig_pos[..., 2] / max(sig_W - 1, 1) * (x_size[j][2] - 1)
sig_pos = sig_pos.to(torch.int32)
else:
sig_len = SEQ_MULTI_OF
sig_out = torch.zeros((sig_len, self.config.siglip_feat_dim), dtype=dtype, device=device)
sig_pos = (
self.create_coordinate_grid((1, 1, 1), (0, 0, 0), device).flatten(0, 2).repeat(sig_len, 1)
)
sig_mask = torch.ones(sig_len, dtype=torch.bool, device=device)
sig_nm = [noise_val] * sig_len
sig_feats_list.append(sig_out)
sig_pos_list.append(sig_pos)
sig_mask_list.append(sig_mask)
sig_lens.append(sig_len)
sig_noise.extend(sig_nm)
all_sig_out.append(torch.cat(sig_feats_list, dim=0))
all_sig_pos_ids.append(torch.cat(sig_pos_list, dim=0))
all_sig_pad_mask.append(torch.cat(sig_mask_list, dim=0))
all_sig_len.append(sig_lens)
all_sig_noise_mask.append(sig_noise)
# Compute x position offsets
all_x_pos_offsets = [(sum(all_cap_len[i]), sum(all_cap_len[i]) + sum(all_x_len[i])) for i in range(bsz)]
return (
all_x_out,
all_cap_out,
all_sig_out,
all_x_size,
all_x_pos_ids,
all_cap_pos_ids,
all_sig_pos_ids,
all_x_pad_mask,
all_cap_pad_mask,
all_sig_pad_mask,
all_x_pos_offsets,
all_x_noise_mask,
all_cap_noise_mask,
all_sig_noise_mask,
)
def _prepare_sequence(
self,
feats: List[torch.Tensor],
pos_ids: List[torch.Tensor],
inner_pad_mask: List[torch.Tensor],
pad_token: torch.nn.Parameter,
noise_mask: Optional[List[List[int]]] = None,
device: torch.device = None,
):
"""Prepare sequence: apply pad token, RoPE embed, pad to batch, create attention mask."""
item_seqlens = [len(f) for f in feats]
max_seqlen = max(item_seqlens)
bsz = len(feats)
# Pad token
feats_cat = torch.cat(feats, dim=0)
feats_cat[torch.cat(inner_pad_mask)] = pad_token
feats = list(feats_cat.split(item_seqlens, dim=0))
# RoPE
freqs_cis = list(self.rope_embedder(torch.cat(pos_ids, dim=0)).split([len(p) for p in pos_ids], dim=0))
# Pad to batch
feats = pad_sequence(feats, batch_first=True, padding_value=0.0)
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if noise_mask is not None:
noise_mask_tensor = pad_sequence(
[torch.tensor(m, dtype=torch.long, device=device) for m in noise_mask],
batch_first=True,
padding_value=0,
)[:, : feats.shape[1]]
return feats, freqs_cis, attn_mask, item_seqlens, noise_mask_tensor
def _build_unified_sequence(
self,
x: torch.Tensor,
x_freqs: torch.Tensor,
x_seqlens: List[int],
x_noise_mask: Optional[List[List[int]]],
cap: torch.Tensor,
cap_freqs: torch.Tensor,
cap_seqlens: List[int],
cap_noise_mask: Optional[List[List[int]]],
siglip: Optional[torch.Tensor],
siglip_freqs: Optional[torch.Tensor],
siglip_seqlens: Optional[List[int]],
siglip_noise_mask: Optional[List[List[int]]],
omni_mode: bool,
device: torch.device,
):
"""Build unified sequence: x, cap, and optionally siglip.
Basic mode order: [x, cap]; Omni mode order: [cap, x, siglip]
"""
bsz = len(x_seqlens)
unified = []
unified_freqs = []
unified_noise_mask = []
for i in range(bsz):
x_len, cap_len = x_seqlens[i], cap_seqlens[i]
if omni_mode:
# Omni: [cap, x, siglip]
if siglip is not None and siglip_seqlens is not None:
sig_len = siglip_seqlens[i]
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len], siglip[i][:sig_len]]))
unified_freqs.append(
torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len], siglip_freqs[i][:sig_len]])
)
unified_noise_mask.append(
torch.tensor(
cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i], dtype=torch.long, device=device
)
)
else:
unified.append(torch.cat([cap[i][:cap_len], x[i][:x_len]]))
unified_freqs.append(torch.cat([cap_freqs[i][:cap_len], x_freqs[i][:x_len]]))
unified_noise_mask.append(
torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
)
else:
# Basic: [x, cap]
unified.append(torch.cat([x[i][:x_len], cap[i][:cap_len]]))
unified_freqs.append(torch.cat([x_freqs[i][:x_len], cap_freqs[i][:cap_len]]))
# Compute unified seqlens
if omni_mode:
if siglip is not None and siglip_seqlens is not None:
unified_seqlens = [a + b + c for a, b, c in zip(cap_seqlens, x_seqlens, siglip_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(cap_seqlens, x_seqlens)]
else:
unified_seqlens = [a + b for a, b in zip(x_seqlens, cap_seqlens)]
max_seqlen = max(unified_seqlens)
# Pad to batch
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
if omni_mode:
noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)[
:, : unified.shape[1]
]
return unified, unified_freqs, attn_mask, noise_mask_tensor
def forward(
self,
x: Union[List[torch.Tensor], List[List[torch.Tensor]]],
t,
cap_feats: Union[List[torch.Tensor], List[List[torch.Tensor]]],
return_dict: bool = True,
controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
siglip_feats: Optional[List[List[torch.Tensor]]] = None,
image_noise_mask: Optional[List[List[int]]] = None,
patch_size: int = 2,
f_patch_size: int = 1,
):
"""
Flow: patchify -> t_embed -> x_embed -> x_refine -> cap_embed -> cap_refine
-> [siglip_embed -> siglip_refine] -> build_unified -> main_layers -> final_layer -> unpatchify
"""
assert patch_size in self.all_patch_size and f_patch_size in self.all_f_patch_size
omni_mode = isinstance(x[0], list)
device = x[0][-1].device if omni_mode else x[0].device
if omni_mode:
# Dual embeddings: noisy (t) and clean (t=1)
t_noisy = self.t_embedder(t * self.t_scale).type_as(x[0][-1])
t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale).type_as(x[0][-1])
adaln_input = None
else:
# Single embedding for all tokens
adaln_input = self.t_embedder(t * self.t_scale).type_as(x[0])
t_noisy = t_clean = None
# Patchify
if omni_mode:
(
x,
cap_feats,
siglip_feats,
x_size,
x_pos_ids,
cap_pos_ids,
siglip_pos_ids,
x_pad_mask,
cap_pad_mask,
siglip_pad_mask,
x_pos_offsets,
x_noise_mask,
cap_noise_mask,
siglip_noise_mask,
) = self.patchify_and_embed_omni(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)
else:
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_pad_mask,
cap_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
x_pos_offsets = x_noise_mask = cap_noise_mask = siglip_noise_mask = None
# X embed & refine
x_seqlens = [len(xi) for xi in x]
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](torch.cat(x, dim=0)) # embed
x, x_freqs, x_mask, _, x_noise_tensor = self._prepare_sequence(
list(x.split(x_seqlens, dim=0)), x_pos_ids, x_pad_mask, self.x_pad_token, x_noise_mask, device
)
for layer in self.noise_refiner:
x = (
self._gradient_checkpointing_func(
layer, x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean
)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(x, x_mask, x_freqs, adaln_input, x_noise_tensor, t_noisy, t_clean)
)
# Cap embed & refine
cap_seqlens = [len(ci) for ci in cap_feats]
# cap_feats = self.cap_embedder(torch.cat(cap_feats, dim=0)) # embed
cap_feats = self.semantic_embedder(torch.cat(cap_feats, dim=0))
cap_feats, cap_freqs, cap_mask, _, _ = self._prepare_sequence(
list(cap_feats.split(cap_seqlens, dim=0)), cap_pos_ids, cap_pad_mask, self.cap_pad_token, None, device
)
for layer in self.context_refiner:
cap_feats = (
self._gradient_checkpointing_func(layer, cap_feats, cap_mask, cap_freqs)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(cap_feats, cap_mask, cap_freqs)
)
# Siglip embed & refine
siglip_seqlens = siglip_freqs = None
if omni_mode and siglip_feats[0] is not None and self.siglip_embedder is not None:
siglip_seqlens = [len(si) for si in siglip_feats]
siglip_feats = self.siglip_embedder(torch.cat(siglip_feats, dim=0)) # embed
siglip_feats, siglip_freqs, siglip_mask, _, _ = self._prepare_sequence(
list(siglip_feats.split(siglip_seqlens, dim=0)),
siglip_pos_ids,
siglip_pad_mask,
self.siglip_pad_token,
None,
device,
)
for layer in self.siglip_refiner:
siglip_feats = (
self._gradient_checkpointing_func(layer, siglip_feats, siglip_mask, siglip_freqs)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(siglip_feats, siglip_mask, siglip_freqs)
)
# Unified sequence
unified, unified_freqs, unified_mask, unified_noise_tensor = self._build_unified_sequence(
x,
x_freqs,
x_seqlens,
x_noise_mask,
cap_feats,
cap_freqs,
cap_seqlens,
cap_noise_mask,
siglip_feats,
siglip_freqs,
siglip_seqlens,
siglip_noise_mask,
omni_mode,
device,
)
# Main transformer layers
for layer_idx, layer in enumerate(self.layers):
unified = (
self._gradient_checkpointing_func(
layer, unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean
)
if torch.is_grad_enabled() and self.gradient_checkpointing
else layer(unified, unified_mask, unified_freqs, adaln_input, unified_noise_tensor, t_noisy, t_clean)
)
if controlnet_block_samples is not None and layer_idx in controlnet_block_samples:
unified = unified + controlnet_block_samples[layer_idx]
unified = (
self.all_final_layer[f"{patch_size}-{f_patch_size}"](
unified, noise_mask=unified_noise_tensor, c_noisy=t_noisy, c_clean=t_clean
)
if omni_mode
else self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, c=adaln_input)
)
# Unpatchify
x = self.unpatchify(list(unified.unbind(dim=0)), x_size, patch_size, f_patch_size, x_pos_offsets)
return (x,) if not return_dict else Transformer2DModelOutput(sample=x)
\ No newline at end of file
"""SigVQ: Semantic token embedding extractor for the image decoder."""
import torch
import torch.nn as nn
class _LinearWrapper(nn.Module):
"""Wraps nn.Linear inside a .proj attribute to match diffusers checkpoint key format."""
def __init__(self, in_features, out_features):
super().__init__()
self.proj = nn.Linear(in_features, out_features)
def forward(self, x):
return self.proj(x)
class _FeedForward(nn.Module):
"""SiLU feed-forward matching diffusers key layout: net.0.proj / net.1 / net.2"""
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.net = nn.Sequential(
_LinearWrapper(dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class SigVQ(nn.Module):
"""
Lightweight semantic token extractor.
Maps discrete VQ token IDs to continuous feature vectors via embedding + projection.
Args:
vocab_size: VQ codebook size (default: 16384).
inner_dim: Feature dimension (default: 4096).
"""
def __init__(self, vocab_size: int = 16384, inner_dim: int = 4096):
super().__init__()
self.prior_token_embedding = nn.Embedding(vocab_size, inner_dim)
self.prior_projector = _FeedForward(dim=inner_dim, hidden_dim=inner_dim)
self.requires_grad_(False)
def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
"""
Args:
token_ids: (batch, seq_len) discrete token indices.
Returns:
(batch, seq_len, inner_dim) projected feature vectors.
"""
return self.prior_projector(self.prior_token_embedding(token_ids))
"""Smart image resizing with aspect-ratio preservation and factor alignment."""
import math
from typing import List, Tuple
from PIL import Image
def smart_resize(
height: int,
width: int,
min_pixels: int,
max_pixels: int,
factor: int = 32,
) -> Tuple[int, int]:
"""
Qwen2.5-VL style smart resize.
Scales the image to fit within [min_pixels, max_pixels] while preserving
the aspect ratio, and returns target dimensions aligned to ``factor``.
"""
h_bar = max(round(height / factor) * factor, factor)
w_bar = max(round(width / factor) * factor, factor)
if h_bar * w_bar > max_pixels:
scale = math.sqrt(max_pixels / (height * width))
h_bar = max(math.floor(height * scale / factor) * factor, factor)
w_bar = max(math.floor(width * scale / factor) * factor, factor)
elif h_bar * w_bar < min_pixels:
scale = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * scale / factor) * factor
w_bar = math.ceil(width * scale / factor) * factor
return h_bar, w_bar
def resize_and_center_crop(
img: Image.Image,
target_h: int,
target_w: int,
factor: int = 32,
) -> Image.Image:
"""
Resize the image (preserving aspect ratio) so that it covers the target
dimensions, then center-crop to a factor-aligned size.
"""
width, height = img.size
# Scale so the image just covers the target area
scale = max(target_h / height, target_w / width)
new_h = int(round(height * scale))
new_w = int(round(width * scale))
img = img.resize((new_w, new_h), resample=Image.BICUBIC)
# Center-crop to factor-aligned dimensions
crop_h = (new_h // factor) * factor
crop_w = (new_w // factor) * factor
# Ensure at least target size
crop_h = max(crop_h, target_h)
crop_w = max(crop_w, target_w)
top = (new_h - crop_h) // 2
left = (new_w - crop_w) // 2
img = img.crop((left, top, left + crop_w, top + crop_h))
return img
def smart_resize_images(
image_paths: List[str],
patch_size: int = 16,
merge_size: int = 2,
single_min_pixels: int = 128 * 128,
single_max_pixels: int = 800 * 800,
multi_min_pixels: int = 128 * 128,
multi_max_pixels: int = 448 * 448,
) -> List[Image.Image]:
"""
Smart-resize a list of images for model input.
Uses larger resolution limits for single-image inputs and smaller limits
for multi-image inputs to control total token count.
"""
num_images = len(image_paths)
if num_images == 0:
return []
factor = patch_size * merge_size # 32
if num_images == 1:
min_pixels = single_min_pixels
max_pixels = single_max_pixels
else:
min_pixels = multi_min_pixels
max_pixels = multi_max_pixels
images = []
for path in image_paths:
if path is None:
images.append(path)
continue
img = Image.open(path).convert("RGB")
width, height = img.size
target_h, target_w = smart_resize(height, width, min_pixels, max_pixels, factor)
img = resize_and_center_crop(img, target_h, target_w, factor)
images.append(img)
return images
from .transport import ModelType, PathType, Sampler, Transport, WeightType
def create_transport(
path_type="Linear",
prediction="velocity",
loss_weight=None,
train_eps=None,
sample_eps=None,
snr_type="uniform",
do_shift=True,
seq_len=1024, # corresponding to 512x512
):
"""function for creating Transport object
**Note**: model prediction defaults to velocity
Args:
- path_type: type of path to use; default to linear
- learn_score: set model prediction to score
- learn_noise: set model prediction to noise
- velocity_weighted: weight loss by velocity weight
- likelihood_weighted: weight loss by likelihood weight
- train_eps: small epsilon for avoiding instability during training
- sample_eps: small epsilon for avoiding instability during sampling
"""
if prediction == "noise":
model_type = ModelType.NOISE
elif prediction == "score":
model_type = ModelType.SCORE
else:
model_type = ModelType.VELOCITY
if loss_weight == "velocity":
loss_type = WeightType.VELOCITY
elif loss_weight == "likelihood":
loss_type = WeightType.LIKELIHOOD
else:
loss_type = WeightType.NONE
path_choice = {
"Linear": PathType.LINEAR,
"GVP": PathType.GVP,
"VP": PathType.VP,
}
path_type = path_choice[path_type]
if path_type in [PathType.VP]:
train_eps = 1e-5 if train_eps is None else train_eps
sample_eps = 1e-3 if train_eps is None else sample_eps
elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
train_eps = 1e-3 if train_eps is None else train_eps
sample_eps = 1e-3 if train_eps is None else sample_eps
else: # velocity & [GVP, LINEAR] is stable everywhere
train_eps = 0
sample_eps = 0
# create flow state
state = Transport(
model_type=model_type,
path_type=path_type,
loss_type=loss_type,
train_eps=train_eps,
sample_eps=sample_eps,
snr_type=snr_type,
do_shift=do_shift,
seq_len=seq_len,
)
return state
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
import os
import torch
from tqdm import tqdm
class NoiseScheduleFlow:
def __init__(
self,
schedule="discrete_flow",
):
"""Create a wrapper class for the forward SDE (EDM type)."""
self.T = 1
self.t0 = 0.001
self.schedule = schedule # ['continuous', 'discrete_flow']
self.total_N = 1000
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
return torch.log(self.marginal_alpha(t))
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return 1 - t
@staticmethod
def marginal_std(t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return t
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = torch.log(self.marginal_std(t))
return log_mean_coeff - log_std
@staticmethod
def inverse_lambda(lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
return torch.exp(-lamb)
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.0,
interval_guidance=[0, 1.0],
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
Note that the score function and the noise prediction model follows a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == "discrete":
return (t_continuous - 1.0 / noise_schedule.total_N) * noise_schedule.total_N
elif noise_schedule.schedule == "discrete_flow":
return t_continuous * noise_schedule.total_N
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
return -expand_dims(sigma_t, x.dim()) * output
elif model_type == "flow":
_, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
try:
noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output + x
except:
noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0] + x
return noise
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
The noise predicition model function that is used for DPM-Solver.
"""
guidance_tp = guidance_type
if guidance_tp == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_tp == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
elif guidance_tp == "classifier-free":
if (
guidance_scale == 1.0
or unconditional_condition is None
or not (interval_guidance[0] < t_continuous[0] < interval_guidance[1])
):
return noise_pred_fn(x, t_continuous, cond=condition)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2)
c_in = torch.cat([unconditional_condition, condition])
try:
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
except:
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in)[0].chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond)
assert model_type in ["noise", "x_start", "v", "score", "flow"]
assert guidance_type in [
"uncond",
"classifier",
"classifier-free",
]
return model_fn
class DPM_Solver:
def __init__(
self,
model_fn,
noise_schedule,
algorithm_type="dpmsolver++",
correcting_x0_fn=None,
correcting_xt_fn=None,
thresholding_max_val=1.0,
dynamic_thresholding_ratio=0.995,
):
"""Construct a DPM-Solver.
We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
DPMs (such as stable-diffusion).
To support advanced algorithms in image-to-image applications, we also support corrector functions for
both x0 and xt.
Args:
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
``
def model_fn(x, t_continuous):
return noise
``
The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
correcting_x0_fn: A `str` or a function with the following format:
```
def correcting_x0_fn(x0, t):
x0_new = ...
return x0_new
```
This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
```
x0_pred = data_pred_model(xt, t)
if correcting_x0_fn is not None:
x0_pred = correcting_x0_fn(x0_pred, t)
xt_1 = update(x0_pred, xt, t)
```
If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
correcting_xt_fn: A function with the following format:
```
def correcting_xt_fn(xt, t, step):
x_new = ...
return x_new
```
This function is to correct the intermediate samples xt at each sampling step. e.g.,
```
xt = ...
xt = correcting_xt_fn(xt, t, step)
```
thresholding_max_val: A `float`. The max value for thresholding.
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
"""
self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
self.noise_schedule = noise_schedule
assert algorithm_type in ["dpmsolver", "dpmsolver++"]
self.algorithm_type = algorithm_type
if correcting_x0_fn == "dynamic_thresholding":
self.correcting_x0_fn = self.dynamic_thresholding_fn
else:
self.correcting_x0_fn = correcting_x0_fn
self.correcting_xt_fn = correcting_xt_fn
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
self.thresholding_max_val = thresholding_max_val
self.register_progress_bar()
def register_progress_bar(self, progress_fn=None):
"""
Register a progress bar callback function
Args:
progress_fn: Callback function that takes current step and total steps as parameters
"""
self.progress_fn = progress_fn if progress_fn is not None else lambda step, total: None
def update_progress(self, step, total_steps):
"""
Update sampling progress
Args:
step: Current step number
total_steps: Total number of steps
"""
if hasattr(self, "progress_fn"):
try:
self.progress_fn(step / total_steps, desc=f"Generating {step}/{total_steps}")
except:
self.progress_fn(step, total_steps)
else:
# If no progress_fn registered, use default empty function
pass
def dynamic_thresholding_fn(self, x0, t):
"""
The dynamic thresholding method.
"""
dims = x0.dim()
p = self.dynamic_thresholding_ratio
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
Return the data prediction model (with corrector).
"""
noise = self.noise_prediction_fn(x, t)
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
x0 = (x - sigma_t * noise) / alpha_t
if self.correcting_x0_fn is not None:
x0 = self.correcting_x0_fn(x0, t)
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.algorithm_type == "dpmsolver++":
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device, shift=1.0):
"""Compute the intermediate time steps for sampling.
Args:
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
N: A `int`. The total number of the spacing of the time steps.
device: A torch device.
Returns:
A pytorch tensor of the time steps, with the shape (N + 1,).
"""
if skip_type == "logSNR":
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == "time_quadratic":
t_order = 2
t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
return t
elif skip_type == "time_uniform_flow":
betas = torch.linspace(t_T, t_0, N + 1).to(device)
sigmas = 1.0 - betas
sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
return sigmas
else:
raise ValueError(
f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
)
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
- If order == 1:
We take `steps` of DPM-Solver-1 (i.e. DDIM).
- If order == 2:
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If order == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
============================================
Args:
order: A `int`. The max order for the solver (2 or 3).
steps: A `int`. The total number of function evaluations (NFE).
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
- 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
- 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
device: A torch device.
Returns:
orders: A list of the solver order of each step.
"""
if order == 3:
K = steps // 3 + 1
if steps % 3 == 0:
orders = [3,] * (
K - 2
) + [2, 1]
elif steps % 3 == 1:
orders = [3,] * (
K - 1
) + [1]
else:
orders = [3,] * (
K - 1
) + [2]
elif order == 2:
if steps % 2 == 0:
K = steps // 2
orders = [
2,
] * K
else:
K = steps // 2 + 1
orders = [2,] * (
K - 1
) + [1]
elif order == 1:
K = 1
orders = [
1,
] * steps
else:
raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == "logSNR":
# To reproduce the results in DPM-Solver paper
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
torch.cumsum(
torch.tensor(
[
0,
]
+ orders
),
0,
).to(device)
]
return timesteps_outer, orders
def denoise_to_zero_fn(self, x, s):
"""
Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
"""
return self.data_prediction_fn(x, s)
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
"""
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
if return_intermediate:
return x_t, {"model_s": model_s}
else:
return x_t
else:
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
if return_intermediate:
return x_t, {"model_s": model_s}
else:
return x_t
def singlestep_dpm_solver_second_update(
self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"
):
"""
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
r1: A `float`. The hyperparameter of the second-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
s1 = ns.inverse_lambda(lambda_s1)
log_alpha_s, log_alpha_s1, log_alpha_t = (
ns.marginal_log_mean_coeff(s),
ns.marginal_log_mean_coeff(s1),
ns.marginal_log_mean_coeff(t),
)
sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_11 = torch.expm1(-r1 * h)
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpmsolver":
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
- (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
)
elif solver_type == "taylor":
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
)
else:
phi_11 = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpmsolver":
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
- (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
)
elif solver_type == "taylor":
x_t = (
torch.exp(log_alpha_t - log_alpha_s) * x
- (sigma_t * phi_1) * model_s
- (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
)
if return_intermediate:
return x_t, {"model_s": model_s, "model_s1": model_s1}
else:
return x_t
def singlestep_dpm_solver_third_update(
self,
x,
s,
t,
r1=1.0 / 3.0,
r2=2.0 / 3.0,
model_s=None,
model_s1=None,
return_intermediate=False,
solver_type="dpmsolver",
):
"""
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
r1: A `float`. The hyperparameter of the third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
if r1 is None:
r1 = 1.0 / 3.0
if r2 is None:
r2 = 2.0 / 3.0
ns = self.noise_schedule
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
lambda_s2 = lambda_s + r2 * h
s1 = ns.inverse_lambda(lambda_s1)
s2 = ns.inverse_lambda(lambda_s2)
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
ns.marginal_log_mean_coeff(s),
ns.marginal_log_mean_coeff(s1),
ns.marginal_log_mean_coeff(s2),
ns.marginal_log_mean_coeff(t),
)
sigma_s, sigma_s1, sigma_s2, sigma_t = (
ns.marginal_std(s),
ns.marginal_std(s1),
ns.marginal_std(s2),
ns.marginal_std(t),
)
alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
if self.algorithm_type == "dpmsolver++":
phi_11 = torch.expm1(-r1 * h)
phi_12 = torch.expm1(-r2 * h)
phi_1 = torch.expm1(-h)
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
phi_2 = phi_1 / h + 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
(sigma_s2 / sigma_s) * x
- (alpha_s2 * phi_12) * model_s
+ r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == "dpmsolver":
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
)
elif solver_type == "taylor":
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
(sigma_t / sigma_s) * x
- (alpha_t * phi_1) * model_s
+ (alpha_t * phi_2) * D1
- (alpha_t * phi_3) * D2
)
else:
phi_11 = torch.expm1(r1 * h)
phi_12 = torch.expm1(r2 * h)
phi_1 = torch.expm1(h)
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
phi_2 = phi_1 / h - 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
(torch.exp(log_alpha_s2 - log_alpha_s)) * x
- (sigma_s2 * phi_12) * model_s
- r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == "dpmsolver":
x_t = (
(torch.exp(log_alpha_t - log_alpha_s)) * x
- (sigma_t * phi_1) * model_s
- (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
)
elif solver_type == "taylor":
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
(torch.exp(log_alpha_t - log_alpha_s)) * x
- (sigma_t * phi_1) * model_s
- (sigma_t * phi_2) * D1
- (sigma_t * phi_3) * D2
)
if return_intermediate:
return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
else:
return x_t
def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
"""
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpmsolver", "taylor"]:
raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
ns = self.noise_schedule
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
if solver_type == "dpmsolver":
x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
elif solver_type == "taylor":
x_t = (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
+ (alpha_t * (phi_1 / h + 1.0)) * D1_0
)
else:
phi_1 = torch.expm1(h)
if solver_type == "dpmsolver":
x_t = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- 0.5 * (sigma_t * phi_1) * D1_0
)
elif solver_type == "taylor":
x_t = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- (sigma_t * (phi_1 / h - 1.0)) * D1_0
)
return x_t
def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
"""
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_2),
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
if self.algorithm_type == "dpmsolver++":
phi_1 = torch.expm1(-h)
phi_2 = phi_1 / h + 1.0
phi_3 = phi_2 / h - 0.5
x_t = (
(sigma_t / sigma_prev_0) * x
- (alpha_t * phi_1) * model_prev_0
+ (alpha_t * phi_2) * D1
- (alpha_t * phi_3) * D2
)
else:
phi_1 = torch.expm1(h)
phi_2 = phi_1 / h - 1.0
phi_3 = phi_2 / h - 0.5
x_t = (
(torch.exp(log_alpha_t - log_alpha_prev_0)) * x
- (sigma_t * phi_1) * model_prev_0
- (sigma_t * phi_2) * D1
- (sigma_t * phi_3) * D2
)
return x_t
def singlestep_dpm_solver_update(
self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None
):
"""
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (1,).
t: A pytorch tensor. The ending time, with the shape (1,).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
r1: A `float`. The hyperparameter of the second-order or third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
elif order == 2:
return self.singlestep_dpm_solver_second_update(
x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
)
elif order == 3:
return self.singlestep_dpm_solver_third_update(
x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
)
else:
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
"""
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
model_prev_list: A list of pytorch tensor. The previous computed model values.
t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,)
t: A pytorch tensor. The ending time, with the shape (1,).
order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
elif order == 2:
return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
elif order == 3:
return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
else:
raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
def dpm_solver_adaptive(
self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"
):
"""
The adaptive step size solver based on singlestep DPM-Solver.
Args:
x: A pytorch tensor. The initial value at time `t_T`.
order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
h_init: A `float`. The initial step size (for logSNR).
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
Returns:
x_0: A pytorch tensor. The approximated solution at time `t_0`.
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
"""
ns = self.noise_schedule
s = t_T * torch.ones((1,)).to(x)
lambda_s = ns.marginal_lambda(s)
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
h = h_init * torch.ones_like(s).to(x)
x_prev = x
nfe = 0
if order == 2:
r1 = 0.5
lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
x, s, t, r1=r1, solver_type=solver_type, **kwargs
)
elif order == 3:
r1, r2 = 1.0 / 3.0, 2.0 / 3.0
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
)
higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
)
else:
raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
while torch.abs(s - t_0).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.0):
x = x_higher
s = t
x_prev = x_lower
lambda_s = ns.marginal_lambda(s)
h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
nfe += order
print("adaptive solver nfe", nfe)
return x
def add_noise(self, x, t, noise=None):
"""
Compute the noised input xt = alpha_t * x + sigma_t * noise.
Args:
x: A `torch.Tensor` with shape `(batch_size, *shape)`.
t: A `torch.Tensor` with shape `(t_size,)`.
Returns:
xt with shape `(t_size, batch_size, *shape)`.
"""
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
if noise is None:
noise = torch.randn((t.shape[0], *x.shape), device=x.device)
x = x.reshape((-1, *x.shape))
xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
if t.shape[0] == 1:
return xt.squeeze(0)
else:
return xt
def inverse(
self,
x,
steps=20,
t_start=None,
t_end=None,
order=2,
skip_type="time_uniform",
method="multistep",
lower_order_final=True,
denoise_to_zero=False,
solver_type="dpmsolver",
atol=0.0078,
rtol=0.05,
return_intermediate=False,
):
"""
Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver.
For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
t_T = self.noise_schedule.T if t_end is None else t_end
assert (
t_0 > 0 and t_T > 0
), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
return self.sample(
x,
steps=steps,
t_start=t_0,
t_end=t_T,
order=order,
skip_type=skip_type,
method=method,
lower_order_final=lower_order_final,
denoise_to_zero=denoise_to_zero,
solver_type=solver_type,
atol=atol,
rtol=rtol,
return_intermediate=return_intermediate,
)
def sample(
self,
x,
steps=20,
t_start=None,
t_end=None,
order=2,
skip_type="time_uniform",
method="multistep",
lower_order_final=True,
denoise_to_zero=False,
solver_type="dpmsolver",
atol=0.0078,
rtol=0.05,
return_intermediate=False,
flow_shift=1.0,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
=====================================================
We support the following algorithms for both noise prediction model and data prediction model:
- 'singlestep':
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
The total number of function evaluations (NFE) == `steps`.
Given a fixed NFE == `steps`, the sampling procedure is:
- If `order` == 1:
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If `order` == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
- 'multistep':
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
We initialize the first `order` values by lower order multistep solvers.
Given a fixed NFE == `steps`, the sampling procedure is:
Denote K = steps.
- If `order` == 1:
- We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
- If `order` == 3:
- We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
- 'singlestep_fixed':
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
- 'adaptive':
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
(NFE) and the sample quality.
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
=====================================================
Some advices for choosing the algorithm:
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
e.g., DPM-Solver:
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
e.g., DPM-Solver++:
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
- For **guided sampling with large guidance scale** by DPMs:
Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
skip_type='time_uniform', method='multistep')
We support three types of `skip_type`:
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
- 'time_quadratic': quadratic time for the time steps.
=====================================================
Args:
x: A pytorch tensor. The initial value at time `t_start`
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
steps: A `int`. The total number of function evaluations (NFE).
t_start: A `float`. The starting time of the sampling.
If `T` is None, we use self.noise_schedule.T (default is 1.0).
t_end: A `float`. The ending time of the sampling.
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
e.g. if total_N == 1000, we have `t_end` == 1e-3.
For discrete-time DPMs:
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
For continuous-time DPMs:
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
order: A `int`. The order of DPM-Solver.
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
for diffusion models sampling by diffusion SDEs for low-resolutional images
(such as CIFAR-10). However, we observed that such trick does not matter for
high-resolutional images. As it needs an additional NFE, we do not recommend
it for high-resolutional images.
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
Only valid for `method=multistep` and `steps < 15`. We empirically find that
this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
(especially for steps <= 10). So we recommend to set it to be `True`.
solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
return_intermediate: A `bool`. Whether to save the xt at each step.
When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0.
Returns:
x_end: A pytorch tensor. The approximated solution at time `t_end`.
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
assert (
t_0 > 0 and t_T > 0
), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
if return_intermediate:
assert method in [
"multistep",
"singlestep",
"singlestep_fixed",
], "Cannot use adaptive solver when saving intermediate values"
if self.correcting_xt_fn is not None:
assert method in [
"multistep",
"singlestep",
"singlestep_fixed",
], "Cannot use adaptive solver when correcting_xt_fn is not None"
device = x.device
intermediates = []
with torch.no_grad():
if method == "adaptive":
x = self.dpm_solver_adaptive(
x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
)
elif method == "multistep":
assert steps >= order
timesteps = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device, shift=flow_shift
)
assert timesteps.shape[0] - 1 == steps
# Init the initial values.
step = 0
t = timesteps[step]
t_prev_list = [t]
model_prev_list = [self.model_fn(x, t)]
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
self.update_progress(step + 1, len(timesteps))
# Init the first `order` values by lower order multistep DPM-Solver.
for step in range(1, order):
t = timesteps[step]
x = self.multistep_dpm_solver_update(
x, model_prev_list, t_prev_list, t, step, solver_type=solver_type
)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
t_prev_list.append(t)
model_prev_list.append(self.model_fn(x, t))
# update progress bar
self.update_progress(step + 1, len(timesteps))
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in tqdm(range(order, steps + 1), disable=os.getenv("DPM_TQDM", "False") == "True"):
t = timesteps[step]
# We only use lower order for steps < 10
# if lower_order_final and steps < 10:
if lower_order_final: # recommended by Shuchen Xue
step_order = min(order, steps + 1 - step)
else:
step_order = order
x = self.multistep_dpm_solver_update(
x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type
)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = t
# We do not need to evaluate the final model value.
if step < steps:
model_prev_list[-1] = self.model_fn(x, t)
# update progress bar
self.update_progress(step + 1, len(timesteps))
elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep":
timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device
)
elif method == "singlestep_fixed":
K = steps // order
orders = [
order,
] * K
timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
for step, order in enumerate(orders):
s, t = timesteps_outer[step], timesteps_outer[step + 1]
timesteps_inner = self.get_time_steps(
skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device
)
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step)
if return_intermediate:
intermediates.append(x)
self.update_progress(step + 1, len(timesteps_outer))
else:
raise ValueError(f"Got wrong method {method}")
if denoise_to_zero:
t = torch.ones((1,)).to(device) * t_0
x = self.denoise_to_zero_fn(x, t)
if self.correcting_xt_fn is not None:
x = self.correcting_xt_fn(x, t, step + 1)
if return_intermediate:
intermediates.append(x)
if return_intermediate:
return x, intermediates
else:
return x
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
`v`: a PyTorch tensor with shape [N].
`dim`: a `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
import torch as th
from torchdiffeq import odeint
from .utils import time_shift, get_lin_function
class sde:
"""SDE solver class"""
def __init__(
self,
drift,
diffusion,
*,
t0,
t1,
num_steps,
sampler_type,
):
assert t0 < t1, "SDE sampler has to be in forward time"
self.num_timesteps = num_steps
self.t = th.linspace(t0, t1, num_steps)
self.dt = self.t[1] - self.t[0]
self.drift = drift
self.diffusion = diffusion
self.sampler_type = sampler_type
def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
w_cur = th.randn(x.size()).to(x)
t = th.ones(x.size(0)).to(x) * t
dw = w_cur * th.sqrt(self.dt)
drift = self.drift(x, t, model, **model_kwargs)
diffusion = self.diffusion(x, t)
mean_x = x + drift * self.dt
x = mean_x + th.sqrt(2 * diffusion) * dw
return x, mean_x
def __Heun_step(self, x, _, t, model, **model_kwargs):
w_cur = th.randn(x.size()).to(x)
dw = w_cur * th.sqrt(self.dt)
t_cur = th.ones(x.size(0)).to(x) * t
diffusion = self.diffusion(x, t_cur)
xhat = x + th.sqrt(2 * diffusion) * dw
K1 = self.drift(xhat, t_cur, model, **model_kwargs)
xp = xhat + self.dt * K1
K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
return (
xhat + 0.5 * self.dt * (K1 + K2),
xhat,
) # at last time point we do not perform the heun step
def __forward_fn(self):
"""TODO: generalize here by adding all private functions ending with steps to it"""
sampler_dict = {
"Euler": self.__Euler_Maruyama_step,
"Heun": self.__Heun_step,
}
try:
sampler = sampler_dict[self.sampler_type]
except:
raise NotImplementedError("Smapler type not implemented.")
return sampler
def sample(self, init, model, **model_kwargs):
"""forward loop of sde"""
x = init
mean_x = init
samples = []
sampler = self.__forward_fn()
for ti in self.t[:-1]:
with th.no_grad():
x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
samples.append(x)
return samples
class ode:
"""ODE solver class"""
def __init__(
self,
drift,
*,
t0,
t1,
sampler_type,
num_steps,
atol,
rtol,
do_shift=False,
time_shifting_factor=None,
):
assert t0 < t1, "ODE sampler has to be in forward time"
self.drift = drift
self.do_shift = do_shift
self.t = th.linspace(t0, t1, num_steps)
if time_shifting_factor:
self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
self.atol = atol
self.rtol = rtol
self.sampler_type = sampler_type
def sample(self, x, model, **model_kwargs):
x = x.float()
device = x[0].device if isinstance(x, tuple) else x.device
def _fn(t, x):
t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
model_output = self.drift(x, t, model, **model_kwargs).float()
return model_output
t = self.t.to(device)
if self.do_shift:
mu = get_lin_function(y1=0.5, y2=1.15)(x.shape[1])
t = time_shift(mu, 1.0, t)
atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
return samples
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment