Commit 4a457188 authored by raojy

fix

parent a570aeea
# LLaDA2.0-Uni ComfyUI Nodes
Custom ComfyUI nodes for [LLaDA 2.0-Uni](https://huggingface.co/inclusionAI/LLaDA2.0-Uni) — a unified multimodal diffusion language model supporting **text-to-image generation**, **image understanding (VQA)**, and **image editing**.
## Installation
> ⚠️ These nodes depend on the `encoder/` and `decoder/` modules in the project root. Do **not** copy `apps/comfyui` in isolation — the full repository must be present and the relative path `apps/comfyui` must be preserved.
### Option 1: Clone + symlink (recommended)
```bash
# 1. Clone the full project
git clone https://github.com/inclusionAI/LLaDA2.0-Uni.git
# 2. Symlink into ComfyUI's custom_nodes
cd /path/to/ComfyUI/custom_nodes
ln -s /path/to/LLaDA2.0-Uni/apps/comfyui ./LLaDA2Uni
```
### Option 2: One-line installer
```bash
bash /path/to/LLaDA2.0-Uni/apps/comfyui/install.sh /path/to/ComfyUI
```
### Dependencies
```bash
pip install -r apps/comfyui/requirements.txt
pip install flash-attn --no-build-isolation # optional, recommended
```
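
Because `flash-attn` is optional, it can help to check whether it is importable before choosing the Loader's `attention` value. The helper below is a minimal sketch, not part of the node package; `pick_attention_backend` is a hypothetical name, and the only assumption is that the package installs under the module name `flash_attn`.

```python
import importlib.util


def pick_attention_backend(preferred: str = "flash_attn") -> str:
    """Return "flash_attn" only when the package is importable, else "sdpa".

    Hypothetical helper: the node package itself does not ship this function.
    """
    if preferred == "flash_attn" and importlib.util.find_spec("flash_attn") is None:
        # flash-attn is not installed; fall back to PyTorch SDPA.
        return "sdpa"
    return preferred


if __name__ == "__main__":
    print("attention backend:", pick_attention_backend())
```

The returned string matches the two choices the Loader node exposes, so it can be pasted into the `attention` field directly.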
## Model Weights
In the Loader node, set the model path to either a HuggingFace repo ID or a local directory:
**HuggingFace (auto-download):**
```
inclusionAI/LLaDA2.0-Uni
```
**Local path:**
```
/path/to/LLaDA2.0-Uni
```
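
The two accepted forms can be told apart with a directory check. The sketch below is illustrative only: `resolve_model_path` is a hypothetical helper, and it assumes that `huggingface_hub.snapshot_download` is used to turn a repo ID into a local snapshot directory. The `downloader` argument exists so the logic can be exercised without network access.

```python
import os
from typing import Callable, Optional


def resolve_model_path(model_path: str,
                       downloader: Optional[Callable[[str], str]] = None) -> str:
    """Return a local directory for either a local path or a HuggingFace repo ID.

    Hypothetical helper, not part of the node package.
    """
    if os.path.isdir(model_path):
        # Already a local checkout of the weights.
        return os.path.abspath(model_path)
    if downloader is None:
        # Imported lazily so the local-path case needs no extra dependency.
        from huggingface_hub import snapshot_download

        def downloader(repo_id: str) -> str:
            return snapshot_download(repo_id=repo_id)
    # Treat anything that is not a directory as a repo ID.
    return downloader(model_path)
```

Resolving to a local directory up front matters because the decoder builds sub-paths such as `decoder/` and `vae/` by joining them onto the model path.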
Expected directory layout:
```
LLaDA2.0-Uni/
├── config.json                  # LLM config
├── model-*.safetensors          # LLM weights
├── tokenizer.json
├── decoder/
│   ├── config.json
│   └── model.safetensors        # diffusion decoder
├── decoder-turbo/
│   ├── config.json
│   └── model.safetensors        # turbo decoder (8-step)
├── vae/
│   └── diffusion_pytorch_model.safetensors
└── image_tokenizer/
    ├── config.json
    ├── preprocessor_config.json
    ├── model.safetensors        # SigLIP-VQ weights
    └── sigvq_embedding.pt
```
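
A local directory can be checked against this layout before it is handed to the Loader. The following is a minimal sketch using only the file names listed above; `REQUIRED_FILES` and `missing_model_files` are hypothetical names introduced for illustration.

```python
import glob
import os
from typing import List

# Relative paths taken from the expected directory layout.
REQUIRED_FILES = [
    "config.json",
    "tokenizer.json",
    "decoder/config.json",
    "decoder/model.safetensors",
    "decoder-turbo/config.json",
    "decoder-turbo/model.safetensors",
    "vae/diffusion_pytorch_model.safetensors",
    "image_tokenizer/config.json",
    "image_tokenizer/preprocessor_config.json",
    "image_tokenizer/model.safetensors",
    "image_tokenizer/sigvq_embedding.pt",
]


def missing_model_files(root: str) -> List[str]:
    """List the expected files that are absent under `root`."""
    missing = [
        rel for rel in REQUIRED_FILES
        if not os.path.isfile(os.path.join(root, *rel.split("/")))
    ]
    # The LLM weights are sharded, so accept any model-*.safetensors file.
    if not glob.glob(os.path.join(glob.escape(root), "model-*.safetensors")):
        missing.append("model-*.safetensors")
    return missing
```

An empty return value means every file from the layout was found; otherwise the list names what still needs to be downloaded.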
## Nodes
| Node | Description |
|------|-------------|
| **LLaDA2.0_Uni Loader** | Load the model (Flash Attention / SDPA, optional CPU offload) |
| **LLaDA2.0_Uni Text-to-Image** | Generate VQ image tokens from a text prompt (supports thinking mode) |
| **LLaDA2.0_Uni Image Understanding** | Visual question answering |
| **LLaDA2.0_Uni Image Editing** | Edit an image with a text instruction |
| **LLaDA2.0_Uni Token Decoder** | Decode VQ tokens to pixels (turbo or normal mode) |
| **LLaDA2.0_Uni Unload Model** | Manually free VRAM |
## Example Workflows
### Text-to-Image
```
Loader → Text-to-Image → Token Decoder → Preview Image
```
### Image Understanding
```
Load Image + Loader → Image Understanding → Show Text
```
### Image Editing
```
Load Image + Loader → Image Editing → Token Decoder → Preview Image
```
## Parameters
### Loader
- `model_path`: HuggingFace repo ID or local directory
- `attention`: `flash_attn` (recommended) or `sdpa`
- `dtype`: `bf16` (recommended) or `fp8` (currently computed in bf16 for compatibility)
- `offload`: enable CPU offload for limited VRAM
- `device`: `cuda` or `cpu`
### Text-to-Image
- `prompt`: text description
- `width` / `height`: output resolution
- `steps`: LLM denoising steps (default `8`, maximum `32`)
- `cfg_scale`: classifier-free guidance scale
- `mode`: `standard` or `thinking`
- `seed`: random seed (`-1` = random)
- `block_length`: block size for block-wise denoising
### Token Decoder
- `decode_mode`: `decoder-turbo` (fast, 8 steps) or `normal` (50 steps by default)
- `decoder_steps`: number of steps when using `normal` mode (ignored in `decoder-turbo` mode)
- `resolution_multiplier`: upscale factor (typically `2`)
- `unload_after`: release all cached models from VRAM after decoding (leave `False` to keep the decoder cached for faster repeated decodes)
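
The effect of `resolution_multiplier` on the output size follows from the arithmetic in the decoder: each token-grid cell covers 16 pixels, scaled by the multiplier. The sketch below restates that arithmetic; `decoded_size` and `latent_shape` are hypothetical helper names, while the formulas mirror the ones in `decode_tokens`.

```python
from typing import Tuple


def decoded_size(h: int, w: int, resolution_multiplier: int = 2) -> Tuple[int, int]:
    """Pixel height and width of the decoded image for an h x w token grid."""
    return h * 16 * resolution_multiplier, w * 16 * resolution_multiplier


def latent_shape(h: int, w: int, resolution_multiplier: int = 2) -> Tuple[int, ...]:
    """Shape of the initial noise tensor fed to the ODE sampler."""
    th, tw = decoded_size(h, w, resolution_multiplier)
    return (1, 16, 1, 2 * (th // 16), 2 * (tw // 16))


if __name__ == "__main__":
    # A 32 x 32 token grid corresponds to a 512 px source resolution.
    print(decoded_size(32, 32, 2))
    print(latent_shape(32, 32, 2))
```

With the typical multiplier of `2`, a 32 × 32 token grid decodes to a 1024 × 1024 image.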
## License
Same as the parent project. See the repository root for details.
"""
ComfyUI Custom Nodes for LLaDA2.0_Uni
Unified multimodal: Text-to-Image, Image Understanding (VQA), Image Editing
This node package lives inside LLaDA2.0_Uni/apps/comfyui/ and imports
encoder/decoder from the parent project directly.
"""
from .nodes import (
    LLaDA2UniLoader,
    LLaDA2UniTextToImage,
    LLaDA2UniImageUnderstanding,
    LLaDA2UniImageEditing,
    LLaDA2UniImageDecode,
    LLaDA2UniUnloadModel,
)
NODE_CLASS_MAPPINGS = {
    "LLaDA2UniLoader": LLaDA2UniLoader,
    "LLaDA2UniTextToImage": LLaDA2UniTextToImage,
    "LLaDA2UniImageUnderstanding": LLaDA2UniImageUnderstanding,
    "LLaDA2UniImageEditing": LLaDA2UniImageEditing,
    "LLaDA2UniImageDecode": LLaDA2UniImageDecode,
    "LLaDA2UniUnloadModel": LLaDA2UniUnloadModel,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LLaDA2UniLoader": "LLaDA2.0_Uni Loader",
    "LLaDA2UniTextToImage": "LLaDA2.0_Uni Text-to-Image",
    "LLaDA2UniImageUnderstanding": "LLaDA2.0_Uni Image Understanding",
    "LLaDA2UniImageEditing": "LLaDA2.0_Uni Image Editing",
    "LLaDA2UniImageDecode": "LLaDA2.0_Uni Token Decoder",
    "LLaDA2UniUnloadModel": "LLaDA2.0_Uni Unload Model",
}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
#!/bin/bash
# Quick installer for the LLaDA2.0_Uni ComfyUI nodes
set -e
echo "=========================================="
echo "LLaDA2.0_Uni ComfyUI node installation"
echo "=========================================="
# Check the ComfyUI path
if [ -z "$1" ]; then
    echo "Usage: $0 /path/to/ComfyUI"
    echo ""
    echo "Example:"
    echo "  $0 ~/ComfyUI"
    exit 1
fi
COMFYUI_PATH="$1"
CUSTOM_NODES_DIR="$COMFYUI_PATH/custom_nodes"
if [ ! -d "$COMFYUI_PATH" ]; then
    echo "❌ ComfyUI path does not exist: $COMFYUI_PATH"
    exit 1
fi
if [ ! -d "$CUSTOM_NODES_DIR" ]; then
    echo "❌ custom_nodes directory does not exist: $CUSTOM_NODES_DIR"
    exit 1
fi
# Directory containing this script (i.e. apps/comfyui)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Verify the project layout is complete (encoder/ and decoder/ must exist two levels up)
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
if [ ! -d "$PROJECT_ROOT/encoder" ] || [ ! -d "$PROJECT_ROOT/decoder" ]; then
    echo "❌ Project layout is incomplete: encoder/ or decoder/ is missing"
    echo "   Project root: $PROJECT_ROOT"
    echo ""
    echo "Make sure you obtained the full project via git clone:"
    echo "  git clone https://github.com/inclusionAI/LLaDA2.0-Uni.git"
    exit 1
fi
echo "📂 ComfyUI path: $COMFYUI_PATH"
echo "📂 Node source path: $SCRIPT_DIR"
echo "📂 Project root: $PROJECT_ROOT"
echo ""
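# Create the symlink
TARGET_DIR="$CUSTOM_NODES_DIR/LLaDA2Uni"
# -L also catches a dangling symlink, which -e alone reports as missing
if [ -e "$TARGET_DIR" ] || [ -L "$TARGET_DIR" ]; then
    echo "⚠️  Target already exists: $TARGET_DIR"
    read -p "Delete and recreate it? (y/n) " -n 1 -r
    echo
    if [[ $REPLY =~ ^[Yy]$ ]]; then
        rm -rf "$TARGET_DIR"
    else
        echo "❌ Installation cancelled"
        exit 1
    fi
fi
ln -s "$SCRIPT_DIR" "$TARGET_DIR"
echo "✅ Symlink created: $TARGET_DIR -> $SCRIPT_DIR"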
# Install dependencies
echo ""
echo "📦 Installing Python dependencies..."
pip install -q -r "$SCRIPT_DIR/requirements.txt"
echo "✅ Dependencies installed"
# Optional: Flash Attention
echo ""
read -p "Install Flash Attention 2? (recommended, but slow to compile) (y/n) " -n 1 -r
echo
if [[ $REPLY =~ ^[Yy]$ ]]; then
    echo "📦 Installing Flash Attention 2..."
    pip install flash-attn --no-build-isolation || echo "⚠️  Flash Attention installation failed; you can continue with SDPA"
fi
echo ""
echo "=========================================="
echo "✅ Installation complete!"
echo "=========================================="
echo ""
echo "Next steps:"
echo "1. Start ComfyUI:"
echo "   cd $COMFYUI_PATH && python main.py"
echo ""
echo "2. Open in a browser: http://localhost:8188"
echo ""
echo "3. Right-click → Add Node → LLaDA2.0_Uni"
echo ""
echo "4. Set the model path in the Loader node:"
echo "   inclusionAI/LLaDA2.0-Uni"
echo "   (downloaded automatically from HuggingFace on first use; a local path also works)"
echo ""
echo "Full documentation: $SCRIPT_DIR/README.md"
"""
Model manager for LLaDA2.0_Uni ComfyUI nodes.
Handles loading/unloading, attention backends, CPU offload, VRAM management,
and decoder model caching.
"""
import torch
import gc
import sys
import os
from typing import Dict, Any
# ── Add project root to sys.path so encoder/ and decoder/ are importable ──
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
# ── Global state ──
_LLM_MODEL = None
_LLM_TOKENIZER = None
_IMAGE_TOKENIZER = None
_MODEL_PATH = None
_ATTENTION = None
_DEVICE = "cuda"
_OFFLOAD = False
_DTYPE = "bf16"
# Decoder cache (module-level)
_SIGVQ_MODEL = None
_DIFF_MODEL = None
_DIFF_MODE = None
_DIFF_CONFIG = None
_VAE_MODEL = None
_DECODER_MODEL_PATH = None
def _resolve_torch_dtype(dtype: str):
    if dtype == "bf16":
        return torch.bfloat16
    if dtype == "fp8":
        print("[LLaDA2.0_Uni] FP8 mode: using bf16 compute dtype for compatibility.")
        return torch.bfloat16
    raise ValueError(f"Unsupported dtype: {dtype}")
# ═══════════════════════════════════════════════════════════════
# LLM Loading
# ═══════════════════════════════════════════════════════════════
def load_llm(model_path: str, device: str = "cuda", attention: str = "flash_attn",
             offload: bool = False, dtype: str = "bf16"):
    """Load the dLLM-MoE backbone. Returns (model, tokenizer)."""
    global _LLM_MODEL, _LLM_TOKENIZER, _MODEL_PATH, _ATTENTION, _DEVICE, _OFFLOAD, _DTYPE
    # Compare against the cached dtype before overwriting it; otherwise the
    # dtype check would always pass and a dtype change would never reload.
    cache_hit = _LLM_MODEL is not None and _MODEL_PATH == model_path and _DTYPE == dtype
    _ATTENTION = attention
    _DEVICE = device
    _OFFLOAD = offload
    _DTYPE = dtype
    if cache_hit:
        return _LLM_MODEL, _LLM_TOKENIZER
    unload_llm()
    from transformers import AutoModelForCausalLM, AutoTokenizer
    attn_kwargs = {"trust_remote_code": True}
    if attention == "sdpa":
        attn_kwargs["attn_implementation"] = "sdpa"
    if offload:
        attn_kwargs["device_map"] = "auto"
        attn_kwargs["max_memory"] = {0: "20GiB", "cpu": "80GiB"}
        attn_kwargs["offload_folder"] = "offload_cache"
        attn_kwargs["torch_dtype"] = _resolve_torch_dtype(dtype)
    else:
        attn_kwargs["device_map"] = device
        attn_kwargs["torch_dtype"] = _resolve_torch_dtype(dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, **attn_kwargs).eval()
    model.tokenizer = tokenizer
    _LLM_MODEL = model
    _LLM_TOKENIZER = tokenizer
    _MODEL_PATH = model_path
    return model, tokenizer
# ═══════════════════════════════════════════════════════════════
# Image Tokenizer
# ═══════════════════════════════════════════════════════════════
def get_image_tokenizer(model_path: str, device: str = "cuda"):
    """Load the SigLIP-VQ image tokenizer."""
    global _IMAGE_TOKENIZER
    if _IMAGE_TOKENIZER is None:
        from encoder.image_tokenizer import ImageTokenizer
        _IMAGE_TOKENIZER = ImageTokenizer(model_path=model_path, device=device)
    return _IMAGE_TOKENIZER
# ═══════════════════════════════════════════════════════════════
# Decoder (with caching)
# ═══════════════════════════════════════════════════════════════
def decode_tokens(token_ids, h, w, model_path: str, device: str = "cuda",
                  num_steps: int = 50, decode_mode: str = "normal",
                  resolution_multiplier: int = 2, progress_callback=None):
    """Decode VQ tokens to PIL image, with model caching."""
    global _SIGVQ_MODEL, _DIFF_MODEL, _DIFF_MODE, _DIFF_CONFIG, _VAE_MODEL, _DECODER_MODEL_PATH
    import json
    import torch.nn.functional as F
    from tqdm import tqdm
    from torchvision.transforms.functional import to_pil_image
    from diffusers import AutoencoderKL
    from safetensors.torch import load_file
    from decoder.sigvq import SigVQ
    from decoder.decoder_model import ZImageTransformer2DModel
    from decoder.transport import create_transport, Sampler
    dtype = torch.bfloat16
    # Evaluate the path check once, before _DECODER_MODEL_PATH is updated, so a
    # changed model path invalidates all three cached components (not only SigVQ).
    path_changed = _DECODER_MODEL_PATH != model_path
    # ── Stage 1: SigVQ → semantic features (cached) ──
    sigvq_path = os.path.join(model_path, "image_tokenizer", "sigvq_embedding.pt")
    if _SIGVQ_MODEL is None or path_changed:
        extractor = SigVQ(vocab_size=16384, inner_dim=4096).to(device, dtype=dtype)
        extractor.load_state_dict(torch.load(sigvq_path, map_location=device, weights_only=True))
        extractor.eval()
        _SIGVQ_MODEL = extractor
        print("[LLaDA2.0_Uni Decoder] SigVQ loaded and cached.")
    th = h * 16 * resolution_multiplier
    tw = w * 16 * resolution_multiplier
    tok = torch.tensor(token_ids).view(1, 1, h, w).float().to(device)
    up = F.interpolate(tok, scale_factor=2, mode="nearest").long().view(1, -1)
    cap_pos = [_SIGVQ_MODEL(up).squeeze(0)]
    cap_neg = [torch.zeros_like(cap_pos[0])]
    # ── Stage 2: Diffusion ODE sampling (cached) ──
    if decode_mode == "decoder-turbo":
        decoder_dir = os.path.join(model_path, "decoder-turbo")
    else:
        decoder_dir = os.path.join(model_path, "decoder")
    if _DIFF_MODEL is None or _DIFF_MODE != decode_mode or path_changed:
        # Free the old model if the mode or path changed. Rebinding to None
        # (rather than `del`) keeps the global defined if loading fails below.
        if _DIFF_MODEL is not None:
            _DIFF_MODEL = None
            gc.collect()
            torch.cuda.empty_cache()
        config_path = os.path.join(decoder_dir, "config.json")
        with open(config_path) as f:
            cfg = json.load(f)
        cfg["axes_lens"] = [32768, 1024, 1024]
        cfg["cap_feat_dim"] = 4096
        # Build on the meta device, then assign the loaded weights directly.
        with torch.device("meta"):
            diff_model = ZImageTransformer2DModel(**cfg)
        ckpt = os.path.join(decoder_dir, "model.safetensors")
        diff_model.load_state_dict(load_file(ckpt, device=str(device)), assign=True)
        diff_model = diff_model.to(dtype=dtype).eval()
        _DIFF_MODEL = diff_model
        _DIFF_MODE = decode_mode
        _DIFF_CONFIG = cfg
        print(f"[LLaDA2.0_Uni Decoder] Diffusion model ({decode_mode}) loaded and cached.")
    cfg = _DIFF_CONFIG
    # Create model function for sampling
    n = len(cap_pos)
    doubled = cap_pos + cap_neg
    cfg_scale = 0.0 if decode_mode == "decoder-turbo" else 1.0
    patch_size = cfg.get("all_patch_size", (2,))[0]
    f_patch_size = cfg.get("all_f_patch_size", (1,))[0]

    def model_fn(x, t, **kw):
        t_t = torch.tensor([t], device=x.device, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t.float()
        if t_t.dim() == 0:
            t_t = t_t.unsqueeze(0)
        if t_t.shape[0] == 1 and x.shape[0] > 1:
            t_t = t_t.expand(x.shape[0])
        if cfg_scale > 0:
            # Run the conditional and unconditional branches in one batch.
            out = _DIFF_MODEL(x=list(x.to(dtype).repeat(2, 1, 1, 1, 1).unbind(0)), t=t_t.repeat(2),
                              cap_feats=doubled, patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
            pos, neg = out[0][:n], out[0][n:]
            res = []
            for p, ng in zip(pos, neg):
                p, ng = p.float(), ng.float()
                pred = p + cfg_scale * (p - ng)
                # Rescale so guidance never increases the prediction norm.
                on, nn_ = torch.linalg.vector_norm(p), torch.linalg.vector_norm(pred)
                if nn_ > on:
                    pred *= on / nn_
                res.append(pred)
            return torch.stack(res)
        out = _DIFF_MODEL(x=list(x.to(dtype).unbind(0)), t=t_t, cap_feats=cap_pos,
                          patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
        return torch.stack([o.float() for o in out[0]])

    z = torch.randn([1, 16, 1, 2 * (th // 16), 2 * (tw // 16)], device=device)
    sampler = Sampler(create_transport("Linear", "velocity", None))
    sample_fn = sampler.sample_ode(
        sampling_method="euler", num_steps=num_steps,
        atol=1e-6, rtol=1e-3, reverse=False, time_shifting_factor=6,
        stochast_ratio=1.0 if decode_mode == "decoder-turbo" else 0.0)
    step_counter = [0]
    if progress_callback is not None:
        def wrapped(x, t, **kw):
            step_counter[0] += 1
            progress_callback(step_counter[0], num_steps)
            return model_fn(x, t, **kw)
    else:
        pbar = tqdm(total=num_steps, desc="Decoding", leave=False)

        def wrapped(x, t, **kw):
            pbar.update(1)
            return model_fn(x, t, **kw)
    with torch.inference_mode():
        samples = sample_fn(z, wrapped)[-1].squeeze(2)
    if progress_callback is None:
        pbar.close()
    # ── Stage 3: VAE decode (cached) ──
    vae_dir = os.path.join(model_path, "vae")
    if _VAE_MODEL is None or path_changed:
        _VAE_MODEL = AutoencoderKL.from_pretrained(vae_dir, torch_dtype=dtype).to(device).eval()
        print("[LLaDA2.0_Uni Decoder] VAE loaded and cached.")
    # Record the path only after every component has been (re)loaded from it.
    _DECODER_MODEL_PATH = model_path
    with torch.inference_mode():
        s = samples.to(dtype)
        s = (s / _VAE_MODEL.config.scaling_factor) + _VAE_MODEL.config.shift_factor
        px = ((_VAE_MODEL.decode(s, return_dict=False)[0] + 1) / 2).clamp_(0, 1)
    return to_pil_image(px[0].float())
# ═══════════════════════════════════════════════════════════════
# Unload functions
# ═══════════════════════════════════════════════════════════════
def unload_llm():
    """Unload LLM backbone to free VRAM."""
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None:
        del _LLM_MODEL, _LLM_TOKENIZER
    _LLM_MODEL = None
    _LLM_TOKENIZER = None
    gc.collect()
    torch.cuda.empty_cache()
def unload_decoder():
    """Unload all decoder components from VRAM."""
    global _SIGVQ_MODEL, _DIFF_MODEL, _DIFF_MODE, _DIFF_CONFIG, _VAE_MODEL, _DECODER_MODEL_PATH
    # Dropping the module-level references is what releases the models; a
    # `del` on a loop variable would only remove the temporary name.
    _SIGVQ_MODEL = None
    _DIFF_MODEL = None
    _DIFF_MODE = None
    _DIFF_CONFIG = None
    _VAE_MODEL = None
    _DECODER_MODEL_PATH = None
    gc.collect()
    torch.cuda.empty_cache()
def unload_image_tokenizer():
    """Unload image tokenizer."""
    global _IMAGE_TOKENIZER
    if _IMAGE_TOKENIZER is not None:
        del _IMAGE_TOKENIZER
    _IMAGE_TOKENIZER = None
    gc.collect()
    torch.cuda.empty_cache()
def unload_all():
    """Unload everything. Call this to free all VRAM."""
    unload_llm()
    unload_decoder()
    unload_image_tokenizer()
"""
ComfyUI Custom Nodes for LLaDA2.0_Uni
Supports: Text-to-Image, Image Understanding, Image Editing, Decode, Unload
Features:
- Attention backend selection: Flash Attention, SDPA
- CPU offload mode for limited VRAM
- Auto VRAM cleanup after inference
- Decoder model caching (60-67% speedup on repeated decodes)
- Manual unload node
"""
import torch
import numpy as np
from PIL import Image
from typing import Optional
import os
import shutil
import subprocess
# ═══════════════════════════════════════════════════════════════
# LLaDA2.0_Uni Loader
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniLoader:
    """Load LLaDA2.0_Uni model with attention & offload options."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model_path": ("STRING", {"default": "inclusionAI/LLaDA2.0-Uni"}),
                "attention": (["flash_attn", "sdpa"],),
                "dtype": (["bf16", "fp8"],),
                "offload": ("BOOLEAN", {"default": False}),
                "device": (["cuda", "cpu"],),
            }
        }

    RETURN_TYPES = ("LLADA2UNI_MODEL",)
    RETURN_NAMES = ("model",)
    FUNCTION = "load"
    CATEGORY = "LLaDA2.0_Uni"
    DESCRIPTION = "Load LLaDA2.0_Uni model from specified path."

    def load(self, model_path, attention, dtype, offload, device):
        if not isinstance(device, str) or not device:
            device = "cuda"
        # Accept a local directory or a HuggingFace repo ID (the default value
        # is a repo ID). The decoder joins sub-paths onto model_path, so a repo
        # ID is resolved to a local snapshot directory first.
        if not os.path.isdir(model_path):
            from huggingface_hub import snapshot_download
            try:
                model_path = snapshot_download(repo_id=model_path)
            except Exception as exc:
                raise ValueError(
                    f"Model path does not exist and could not be downloaded: {model_path}"
                ) from exc
        from .model_manager import load_llm
        model, tokenizer = load_llm(model_path, device=device,
                                    attention=attention, offload=offload, dtype=dtype)
        return ({
            "model_path": model_path,
            "device": device,
            "attention": attention,
            "dtype": dtype,
            "offload": offload,
        },)

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        return float("NaN")
# ═══════════════════════════════════════════════════════════════
# Text-to-Image
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniTextToImage:
    """Generate image tokens from text prompt."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("LLADA2UNI_MODEL",),
                "prompt": ("STRING", {"multiline": True, "default": ""}),
                "width": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 64}),
                "height": ("INT", {"default": 1024, "min": 256, "max": 2048, "step": 64}),
                "steps": ("INT", {"default": 8, "min": 1, "max": 32}),
                "cfg_scale": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
            },
            "optional": {
                "mode": (["standard", "thinking"],),
                "thinking_steps": ("INT", {"default": 32, "min": 1, "max": 64}),
                "thinking_length": ("INT", {"default": 4096, "min": 512, "max": 8192, "step": 512}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 2**32 - 1}),
                "block_length": ("INT", {"default": 32, "min": 1, "max": 128}),
                "unload_after": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("LLADA2UNI_TOKENS", "STRING",)
    RETURN_NAMES = ("tokens", "thinking_output",)
    FUNCTION = "generate"
    CATEGORY = "LLaDA2.0_Uni"
    DESCRIPTION = "Generate VQ image tokens from text prompt."

    def generate(self, model, prompt, width, height, steps, cfg_scale,
                 mode="standard", thinking_steps=32, thinking_length=4096,
                 seed=-1, block_length=32, unload_after=True):
        # Set seed for reproducibility
        if seed >= 0:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        from .model_manager import load_llm, unload_llm
        llm, _ = load_llm(model["model_path"], model["device"],
                          model.get("attention", "flash_attn"),
                          model.get("offload", False),
                          model.get("dtype", "bf16"))
        # Generate image tokens
        if mode == "thinking":
            result = llm.generate_image(
                prompt, image_h=height, image_w=width,
                mode="thinking",
                steps=steps, cfg_scale=cfg_scale,
                thinking_steps=thinking_steps,
                thinking_gen_length=thinking_length,
                block_length=block_length,
            )
            thinking_text = result.get("thinking", "")
        else:
            result = llm.generate_image(
                prompt, image_h=height, image_w=width,
                steps=steps, cfg_scale=cfg_scale,
                block_length=block_length,
            )
            thinking_text = ""
        token_data = {
            "token_ids": result["token_ids"],
            "h": result["h"],
            "w": result["w"],
            "model_path": model["model_path"],
        }
        # Auto VRAM cleanup
        if unload_after:
            unload_llm()
        return (token_data, thinking_text,)
# ═══════════════════════════════════════════════════════════════
# Image Understanding (VQA)
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniImageUnderstanding:
    """Understand/describe an image using VQA."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("LLADA2UNI_MODEL",),
                "image": ("IMAGE",),
                "question": ("STRING", {"multiline": True, "default": "Describe this image in detail."}),
            },
            "optional": {
                "gen_steps": ("INT", {"default": 32, "min": 1, "max": 64}),
                "gen_length": ("INT", {"default": 2048, "min": 256, "max": 8192, "step": 256}),
                "block_length": ("INT", {"default": 32, "min": 1, "max": 128}),
                "unload_after": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("STRING",)
    RETURN_NAMES = ("response",)
    FUNCTION = "understand"
    CATEGORY = "LLaDA2.0_Uni"
    DESCRIPTION = "Ask questions about an image (VQA)."

    def understand(self, model, image, question, gen_steps=32, gen_length=2048,
                   block_length=32, unload_after=True):
        from .model_manager import load_llm, get_image_tokenizer, unload_llm
        from decoder.smart_img_process import smart_resize_images
        llm, _ = load_llm(model["model_path"], model["device"],
                          model.get("attention", "flash_attn"),
                          model.get("offload", False),
                          model.get("dtype", "bf16"))
        # ComfyUI tensor (B,H,W,C) → PIL; clip so out-of-range values cannot wrap in uint8
        img_np = np.clip(image[0].cpu().numpy() * 255.0, 0, 255).astype(np.uint8)
        pil_image = Image.fromarray(img_np)
        # Encode image
        image_tokenizer = get_image_tokenizer(model["model_path"], model["device"])
        pil_image = smart_resize_images([pil_image])[0]
        info = image_tokenizer.encode_with_info(pil_image)
        image_tokens = [x + llm.config.image_token_offset for x in info["token_ids"]]
        _, h, w = info["grid_thw"]
        # Understand the image
        response = llm.understand_image(
            image_tokens, h, w,
            question=question,
            steps=gen_steps,
            gen_length=gen_length,
            block_length=block_length,
        )
        # Auto VRAM cleanup
        if unload_after:
            unload_llm()
        return (response,)
# ═══════════════════════════════════════════════════════════════
# Image Editing
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniImageEditing:
    """Edit an image based on a text instruction."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("LLADA2UNI_MODEL",),
                "image": ("IMAGE",),
                "instruction": ("STRING", {"multiline": True, "default": ""}),
            },
            "optional": {
                "steps": ("INT", {"default": 8, "min": 1, "max": 32}),
                "cfg_text_scale": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "cfg_image_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                "block_length": ("INT", {"default": 32, "min": 1, "max": 128}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 2**32 - 1}),
                "unload_after": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("LLADA2UNI_TOKENS",)
    RETURN_NAMES = ("tokens",)
    FUNCTION = "edit"
    CATEGORY = "LLaDA2.0_Uni"
    DESCRIPTION = "Edit an image with a text instruction. Connect output to Decode node."

    def edit(self, model, image, instruction, steps=8, cfg_text_scale=4.0,
             cfg_image_scale=0.0, block_length=32, seed=-1, unload_after=True):
        # Set seed for reproducibility
        if seed >= 0:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        from .model_manager import load_llm, get_image_tokenizer, unload_llm
        from decoder.utils import generate_crop_size_list, var_center_crop
        llm, _ = load_llm(model["model_path"], model["device"],
                          model.get("attention", "flash_attn"),
                          model.get("offload", False),
                          model.get("dtype", "bf16"))
        # ComfyUI tensor → PIL; clip so out-of-range values cannot wrap in uint8
        img_np = np.clip(image[0].cpu().numpy() * 255.0, 0, 255).astype(np.uint8)
        pil_image = Image.fromarray(img_np)
        # Encode source image (preserve aspect ratio via var_center_crop)
        image_tokenizer = get_image_tokenizer(model["model_path"], model["device"])
        crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32)
        pil_image = var_center_crop(pil_image, crop_size_list=crop_size_list)
        info = image_tokenizer.encode_with_info(pil_image)
        image_tokens = [x + llm.config.image_token_offset for x in info["token_ids"]]
        _, h, w = info["grid_thw"]
        # Edit the image
        result = llm.edit_image(
            image_tokens, h, w, instruction,
            steps=steps,
            block_length=block_length,
            cfg_text_scale=cfg_text_scale,
            cfg_image_scale=cfg_image_scale,
        )
        token_data = {
            "token_ids": result["token_ids"],
            "h": result["h"],
            "w": result["w"],
            "model_path": model["model_path"],
        }
        # Auto VRAM cleanup
        if unload_after:
            unload_llm()
        return (token_data,)
# ═══════════════════════════════════════════════════════════════
# Token Decoder
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniImageDecode:
    """Decode VQ tokens from LLaDA2.0_Uni into a pixel image."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "tokens": ("LLADA2UNI_TOKENS",),
            },
            "optional": {
                "decode_mode": (["decoder-turbo", "normal"],),
                "decoder_steps": ("INT", {"default": 50, "min": 1, "max": 100}),
                "resolution_multiplier": ("INT", {"default": 2, "min": 1, "max": 4}),
                "unload_after": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "decode"
    CATEGORY = "LLaDA2.0_Uni"
    DESCRIPTION = "Decode VQ tokens to pixel image. decoder-turbo is ~10× faster (8 steps vs 50)."

    def decode(self, tokens, decode_mode="decoder-turbo", decoder_steps=50,
               resolution_multiplier=2, unload_after=False):
        from .model_manager import decode_tokens, unload_all
        # Turbo mode always runs 8 steps; decoder_steps applies to normal mode only
        num_steps = 8 if decode_mode == "decoder-turbo" else decoder_steps
        # ComfyUI progress bar integration
        progress_callback = None
        try:
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_steps)

            def progress_callback(step, total):
                pbar.update(1)
        except ImportError:
            pass
        pil_image = decode_tokens(
            tokens["token_ids"],
            tokens["h"],
            tokens["w"],
            tokens["model_path"],
            device="cuda",
            num_steps=num_steps,
            decode_mode=decode_mode,
            resolution_multiplier=resolution_multiplier,
            progress_callback=progress_callback,
        )
        # PIL → ComfyUI tensor (1, H, W, 3) float32 [0,1]
        img_np = np.array(pil_image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(img_np).unsqueeze(0)
        # Auto VRAM cleanup after full pipeline
        if unload_after:
            unload_all()
        return (tensor,)
# ═══════════════════════════════════════════════════════════════
# Manual Unload
# ═══════════════════════════════════════════════════════════════
class LLaDA2UniUnloadModel:
    """Manually unload all LLaDA2.0_Uni components and free VRAM."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {},
            "optional": {},
        }

    RETURN_TYPES = ()
    OUTPUT_NODE = True
    FUNCTION = "unload"
    CATEGORY = "LLaDA2.0_Uni"
    DESCRIPTION = "Unload all LLaDA2.0_Uni models from VRAM."

    def unload(self):
        from .model_manager import unload_all
        unload_all()
        print("[LLaDA2.0_Uni] ✅ All models unloaded, VRAM freed")
        return ()
torch>=2.4.0
torchvision>=0.19.0
transformers>=4.45.0
safetensors>=0.4.0
diffusers>=0.30.0
Pillow>=10.0.0
numpy>=1.24.0
tqdm>=4.65.0
accelerate>=0.26.0
torchdiffeq
from .sigvq import SigVQ
from .decoder_model import ZImageTransformer2DModel
from .decode import decode_vq_tokens
from .utils import generate_crop_size_list, var_center_crop
from . import transport
"""Decode VQ token IDs into a PIL Image via SigVQ + ZImage diffusion + VAE."""
import os
import json
import gc
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torchvision.transforms.functional import to_pil_image
from diffusers import AutoencoderKL
from safetensors.torch import load_file
from .sigvq import SigVQ
from .decoder_model import ZImageTransformer2DModel
from .transport import create_transport, Sampler
def _create_decoder_model_fn(model, cap_pos, cap_neg, cfg_scale, patch_size, f_patch_size, dtype):
    n = len(cap_pos)
    doubled = cap_pos + cap_neg

    def fn(x, t, **kw):
        t_t = torch.tensor([t], device=x.device, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t.float()
        if t_t.dim() == 0:
            t_t = t_t.unsqueeze(0)
        if t_t.shape[0] == 1 and x.shape[0] > 1:
            t_t = t_t.expand(x.shape[0])
        if cfg_scale > 0:
            # Run the conditional and unconditional branches in one batch.
            out = model(x=list(x.to(dtype).repeat(2, 1, 1, 1, 1).unbind(0)), t=t_t.repeat(2),
                        cap_feats=doubled, patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
            pos, neg = out[0][:n], out[0][n:]
            res = []
            for p, ng in zip(pos, neg):
                p, ng = p.float(), ng.float()
                pred = p + cfg_scale * (p - ng)
                # Rescale so guidance never increases the prediction norm.
                on, nn_ = torch.linalg.vector_norm(p), torch.linalg.vector_norm(pred)
                if nn_ > on:
                    pred *= on / nn_
                res.append(pred)
            return torch.stack(res)
        out = model(x=list(x.to(dtype).unbind(0)), t=t_t, cap_feats=cap_pos,
                    patch_size=patch_size, f_patch_size=f_patch_size, return_dict=False)
        return torch.stack([o.float() for o in out[0]])

    return fn
@torch.inference_mode()
def decode_vq_tokens(token_ids, h, w, model_path, device,
                     resolution_multiplier=2, num_steps=50,
                     decode_mode="normal"):
    """
    Decode VQ token IDs into a PIL Image.

    Args:
        token_ids: List of VQ token IDs (without the +157184 offset).
        h, w: Semantic grid size (image_pixels // 16).
        model_path: Root path of the model directory.
        device: torch device.
        resolution_multiplier: Upscale factor (2 = 1024px from 512px tokens).
        num_steps: ODE sampling steps.
        decode_mode: ``"normal"`` uses the standard decoder (default, 50 steps);
            ``"decoder-turbo"`` uses the distilled decoder (faster, ~8 steps).

    Returns:
        PIL.Image
    """
    dtype = torch.bfloat16
    sigvq_path = os.path.join(model_path, "image_tokenizer", "sigvq_embedding.pt")
    if decode_mode == "decoder-turbo":
        decoder_dir = os.path.join(model_path, "decoder-turbo")
    else:
        decoder_dir = os.path.join(model_path, "decoder")
    vae_dir = os.path.join(model_path, "vae")
    # ---------- Stage 1: SigVQ → semantic features ----------
    extractor = SigVQ(vocab_size=16384, inner_dim=4096).to(device, dtype=dtype)
    extractor.load_state_dict(
        torch.load(sigvq_path, map_location=device, weights_only=True))
    extractor.eval()
    th = h * 16 * resolution_multiplier
    tw = w * 16 * resolution_multiplier
    tok = torch.tensor(token_ids).view(1, 1, h, w).float().to(device)
    up = F.interpolate(tok, scale_factor=2, mode="nearest").long().view(1, -1)
    cap_pos = [extractor(up).squeeze(0)]
    cap_neg = [torch.zeros_like(cap_pos[0])]
    # SigVQ is no longer needed, so release it immediately
    del extractor
    gc.collect()
    torch.cuda.empty_cache()
    # ---------- Stage 2: Diffusion ODE sampling ----------
    config_path = os.path.join(decoder_dir, "config.json")
    with open(config_path) as f:
        cfg = json.load(f)
    cfg["axes_lens"] = [32768, 1024, 1024]
    cfg["cap_feat_dim"] = 4096
    # Build the model on the meta device and assign the loaded weights directly.
    # This avoids the ~12 GB peak from holding both random init + loaded weights.
    with torch.device("meta"):
        diff_model = ZImageTransformer2DModel(**cfg)
    ckpt = os.path.join(decoder_dir, "model.safetensors")
    diff_model.load_state_dict(load_file(ckpt, device=str(device)), assign=True)
    diff_model = diff_model.to(dtype=dtype).eval()
    z = torch.randn([1, 16, 1, 2 * (th // 16), 2 * (tw // 16)], device=device)
    model_fn = _create_decoder_model_fn(
        diff_model, cap_pos, cap_neg,
        cfg_scale=0.0 if decode_mode == "decoder-turbo" else 1.0,
        patch_size=cfg.get("all_patch_size", (2,))[0],
        f_patch_size=cfg.get("all_f_patch_size", (1,))[0],
        dtype=dtype)
    sampler = Sampler(create_transport("Linear", "velocity", None))
    sample_fn = sampler.sample_ode(
        sampling_method="euler", num_steps=num_steps,
        atol=1e-6, rtol=1e-3, reverse=False, time_shifting_factor=6,
        stochast_ratio=1.0 if decode_mode == "decoder-turbo" else 0.0)
    pbar = tqdm(total=num_steps, desc="Decoding", leave=False)

    def wrapped(x, t, **kw):
        pbar.update(1)
        return model_fn(x, t, **kw)

    samples = sample_fn(z, wrapped)[-1].squeeze(2)
    pbar.close()
    # Diffusion model is done, so release it before loading the VAE
    del diff_model, cap_pos, cap_neg, model_fn
    gc.collect()
    torch.cuda.empty_cache()
    # ---------- Stage 3: VAE decode ----------
    vae = AutoencoderKL.from_pretrained(vae_dir, torch_dtype=dtype).to(device).eval()
    s = samples.to(dtype)
    s = (s / vae.config.scaling_factor) + vae.config.shift_factor
    px = ((vae.decode(s, return_dict=False)[0] + 1) / 2).clamp_(0, 1)
    del vae
    gc.collect()
    torch.cuda.empty_cache()
    return to_pil_image(px[0].float())
"""SigVQ: Semantic token embedding extractor for the image decoder."""

import torch
import torch.nn as nn


class _LinearWrapper(nn.Module):
    """Wraps nn.Linear inside a .proj attribute to match diffusers checkpoint key format."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.proj(x)


class _FeedForward(nn.Module):
    """SiLU feed-forward matching diffusers key layout: net.0.proj / net.1 / net.2"""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            _LinearWrapper(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class SigVQ(nn.Module):
    """
    Lightweight semantic token extractor.

    Maps discrete VQ token IDs to continuous feature vectors via embedding + projection.

    Args:
        vocab_size: VQ codebook size (default: 16384).
        inner_dim: Feature dimension (default: 4096).
    """

    def __init__(self, vocab_size: int = 16384, inner_dim: int = 4096):
        super().__init__()
        self.prior_token_embedding = nn.Embedding(vocab_size, inner_dim)
        self.prior_projector = _FeedForward(dim=inner_dim, hidden_dim=inner_dim)
        self.requires_grad_(False)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            token_ids: (batch, seq_len) discrete token indices.

        Returns:
            (batch, seq_len, inner_dim) projected feature vectors.
        """
        return self.prior_projector(self.prior_token_embedding(token_ids))

"""Smart image resizing with aspect-ratio preservation and factor alignment."""

import math
from typing import List, Tuple

from PIL import Image


def smart_resize(
    height: int,
    width: int,
    min_pixels: int,
    max_pixels: int,
    factor: int = 32,
) -> Tuple[int, int]:
    """
    Qwen2.5-VL style smart resize.

    Scales the image to fit within [min_pixels, max_pixels] while preserving
    the aspect ratio, and returns target dimensions aligned to ``factor``.
    """
    h_bar = max(round(height / factor) * factor, factor)
    w_bar = max(round(width / factor) * factor, factor)
    if h_bar * w_bar > max_pixels:
        scale = math.sqrt(max_pixels / (height * width))
        h_bar = max(math.floor(height * scale / factor) * factor, factor)
        w_bar = max(math.floor(width * scale / factor) * factor, factor)
    elif h_bar * w_bar < min_pixels:
        scale = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * scale / factor) * factor
        w_bar = math.ceil(width * scale / factor) * factor
    return h_bar, w_bar


def resize_and_center_crop(
    img: Image.Image,
    target_h: int,
    target_w: int,
    factor: int = 32,
) -> Image.Image:
    """
    Resize the image (preserving aspect ratio) so that it covers the target
    dimensions, then center-crop to a factor-aligned size.
    """
    width, height = img.size

    # Scale so the image just covers the target area
    scale = max(target_h / height, target_w / width)
    new_h = int(round(height * scale))
    new_w = int(round(width * scale))
    img = img.resize((new_w, new_h), resample=Image.BICUBIC)

    # Center-crop to factor-aligned dimensions
    crop_h = (new_h // factor) * factor
    crop_w = (new_w // factor) * factor

    # Ensure at least target size
    crop_h = max(crop_h, target_h)
    crop_w = max(crop_w, target_w)

    top = (new_h - crop_h) // 2
    left = (new_w - crop_w) // 2
    img = img.crop((left, top, left + crop_w, top + crop_h))
    return img


def smart_resize_images(
    image_paths: List[str],
    patch_size: int = 16,
    merge_size: int = 2,
    single_min_pixels: int = 128 * 128,
    single_max_pixels: int = 800 * 800,
    multi_min_pixels: int = 128 * 128,
    multi_max_pixels: int = 448 * 448,
) -> List[Image.Image]:
    """
    Smart-resize a list of images for model input.

    Uses larger resolution limits for single-image inputs and smaller limits
    for multi-image inputs to control total token count.
    """
    num_images = len(image_paths)
    if num_images == 0:
        return []

    factor = patch_size * merge_size  # 32

    if num_images == 1:
        min_pixels = single_min_pixels
        max_pixels = single_max_pixels
    else:
        min_pixels = multi_min_pixels
        max_pixels = multi_max_pixels

    images = []
    for path in image_paths:
        if path is None:
            images.append(path)
            continue
        img = Image.open(path).convert("RGB")
        width, height = img.size
        target_h, target_w = smart_resize(height, width, min_pixels, max_pixels, factor)
        img = resize_and_center_crop(img, target_h, target_w, factor)
        images.append(img)
    return images
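
To make the rounding behaviour of `smart_resize` concrete, the sketch below restates the function so it runs on its own and applies it to three example resolutions. The input sizes are illustrative assumptions, not values from the repository; the pixel limits are the single-image defaults of `smart_resize_images`.

```python
import math
from typing import Tuple


def smart_resize(height: int, width: int, min_pixels: int, max_pixels: int,
                 factor: int = 32) -> Tuple[int, int]:
    """Restated from the module above so the example is self-contained."""
    h_bar = max(round(height / factor) * factor, factor)
    w_bar = max(round(width / factor) * factor, factor)
    if h_bar * w_bar > max_pixels:
        # Too large: shrink, flooring each side to the factor grid.
        scale = math.sqrt(max_pixels / (height * width))
        h_bar = max(math.floor(height * scale / factor) * factor, factor)
        w_bar = max(math.floor(width * scale / factor) * factor, factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: enlarge, rounding each side up to the factor grid.
        scale = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * scale / factor) * factor
        w_bar = math.ceil(width * scale / factor) * factor
    return h_bar, w_bar


MIN_PIXELS = 128 * 128  # single_min_pixels default
MAX_PIXELS = 800 * 800  # single_max_pixels default

# Illustrative (height, width) inputs.
too_large = smart_resize(1080, 1920, MIN_PIXELS, MAX_PIXELS)
in_range = smart_resize(512, 512, MIN_PIXELS, MAX_PIXELS)
too_small = smart_resize(64, 48, MIN_PIXELS, MAX_PIXELS)

for name, (h, w) in [("too_large", too_large),
                     ("in_range", in_range),
                     ("too_small", too_small)]:
    print(f"{name}: {h}x{w} = {h * w} pixels")
```

For the 1080x1920 input the result is 576x1056 (608,256 pixels, under the 640,000 limit) with both sides divisible by 32. The aspect ratio shifts slightly, from 16:9 to 11:6, because each side is floored to the grid independently.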

from .transport import ModelType, PathType, Sampler, Transport, WeightType


def create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    snr_type="uniform",
    do_shift=True,
    seq_len=1024,  # corresponding to 512x512
):
    """Create a Transport object.

    **Note**: model prediction defaults to velocity.

    Args:
    - path_type: type of path to use ("Linear", "GVP", or "VP"); defaults to "Linear"
    - prediction: model prediction target ("noise" or "score"; anything else means velocity)
    - loss_weight: loss weighting ("velocity" or "likelihood"; anything else means none)
    - train_eps: small epsilon for avoiding instability during training
    - sample_eps: small epsilon for avoiding instability during sampling
    - snr_type, do_shift, seq_len: forwarded to Transport unchanged
    """
    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }
    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        snr_type=snr_type,
        do_shift=do_shift,
        seq_len=seq_len,
    )
    return state
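
The epsilon values that `create_transport` ends up with depend on both the path type and the prediction target. The sketch below tabulates them in plain Python without importing `Transport`. `resolve_eps` is a hypothetical helper written for this illustration, and it assumes the intent is for each epsilon to fall back to its default only when the caller left it as `None`.

```python
from typing import Optional, Tuple


def resolve_eps(path_type: str, prediction: str,
                train_eps: Optional[float] = None,
                sample_eps: Optional[float] = None) -> Tuple[float, float]:
    """Return (train_eps, sample_eps) following create_transport's branches."""
    if path_type not in ("Linear", "GVP", "VP"):
        raise KeyError(path_type)  # mirrors the path_choice lookup
    # Anything other than "noise" or "score" is treated as velocity.
    is_velocity = prediction not in ("noise", "score")

    if path_type == "VP":
        train_default, sample_default = 1e-5, 1e-3
    elif not is_velocity:  # GVP or Linear with noise/score prediction
        train_default, sample_default = 1e-3, 1e-3
    else:
        # Velocity prediction on GVP/Linear: both epsilons are forced to 0,
        # overriding any value the caller supplied.
        return 0, 0

    train_eps = train_default if train_eps is None else train_eps
    sample_eps = sample_default if sample_eps is None else sample_eps
    return train_eps, sample_eps


for path in ("Linear", "GVP", "VP"):
    for pred in ("velocity", "noise", "score"):
        print(f"{path:6s} {pred:8s} -> {resolve_eps(path, pred)}")
```

The decode path in this diff calls `create_transport("Linear", "velocity", None)`, which lands in the last branch, so both epsilons are 0 there.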

import torch as th
from torchdiffeq import odeint

from .utils import time_shift, get_lin_function


class sde:
    """SDE solver class"""

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return (
            xhat + 0.5 * self.dt * (K1 + K2),
            xhat,
        )  # at last time point we do not perform the heun step

    def __forward_fn(self):
        """TODO: generalize here by adding all private functions ending with steps to it"""
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError("Sampler type not implemented.")

        return sampler

    def sample(self, init, model, **model_kwargs):
        """forward loop of sde"""
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples


class ode:
    """ODE solver class"""

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
        do_shift=False,
        time_shifting_factor=None,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.drift = drift
        self.do_shift = do_shift
        self.t = th.linspace(t0, t1, num_steps)
        if time_shifting_factor:
            self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        x = x.float()
        device = x[0].device if isinstance(x, tuple) else x.device

        def _fn(t, x):
            if isinstance(x, tuple):
                t = th.ones(x[0].size(0)).to(device) * t
            else:
                t = th.ones(x.size(0)).to(device) * t
            model_output = self.drift(x, t, model, **model_kwargs).float()
            return model_output

        t = self.t.to(device)
        if self.do_shift:
            mu = get_lin_function(y1=0.5, y2=1.15)(x.shape[1])
            t = time_shift(mu, 1.0, t)
        atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
        rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
        samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
        return samples
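
When `time_shifting_factor` is set, `ode.__init__` warps the uniform time grid with t' = t / (t + s - s·t), where s is the shifting factor. The sketch below reproduces that mapping in plain Python, without torch, to show how the step sizes change. `shifted_grid` is a hypothetical helper name, and the 8 grid points are an example value; s = 6 matches the `time_shifting_factor=6` passed in the decode call earlier in this diff.

```python
from typing import List


def shifted_grid(num_points: int, shift: float,
                 t0: float = 0.0, t1: float = 1.0) -> List[float]:
    """Uniform grid on [t0, t1] warped by t / (t + shift - shift * t)."""
    step = (t1 - t0) / (num_points - 1)
    uniform = [t0 + i * step for i in range(num_points)]
    return [t / (t + shift - shift * t) for t in uniform]


grid = shifted_grid(num_points=8, shift=6.0)
# Width of each integration interval between consecutive grid points.
steps = [b - a for a, b in zip(grid, grid[1:])]

print("grid :", [round(t, 4) for t in grid])
print("steps:", [round(dt, 4) for dt in steps])
```

The endpoints 0 and 1 are preserved, but the intervals grow monotonically: with s = 6 the first step spans about 0.027 of the range while the last spans 0.5, so a fixed-step solver spends most of its evaluations near t = 0.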