"""Command-line interleaved text+image inference for SenseNova-U1.

Two modes are supported:

* ``--prompt`` (plus optional, repeatable ``--image``): run a single sample.
* ``--jsonl``: run one sample per line and write a ``results.jsonl`` index
  next to the generated text / image files.
"""

from __future__ import annotations

import argparse
import json
import math
import random
from pathlib import Path
from typing import Sequence

import numpy as np
import torch
from PIL import Image

import sensenova_u1
from sensenova_u1.utils import (
    DEFAULT_IMAGE_PATCH_SIZE,
    DEFAULT_VRAM_MODE,
    InferenceProfiler,
    add_offload_args,
    best_available_device,
    load_and_merge_lora_weight_from_safetensors,
    load_model_and_tokenizer,
    make_offload_ctx,
    seed_all_accelerators,
    vram_mode_to_prefetch_count,
)

# Per-channel normalization applied to image tensors; ``_denorm`` inverts it.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

DEFAULT_SEED = 42

# Aspect-ratio label -> (width, height) in pixels.
SUPPORTED_RESOLUTIONS: dict[str, tuple[int, int]] = {
    "1:1": (1536, 1536),
    "16:9": (2048, 1152),
    "9:16": (1152, 2048),
    "3:2": (1888, 1248),
    "2:3": (1248, 1888),
    "4:3": (1760, 1312),
    "3:4": (1312, 1760),
    "1:2": (1088, 2144),
    "2:1": (2144, 1088),
    "1:3": (864, 2592),
    "3:1": (2592, 864),
}
DEFAULT_RESOLUTION = "16:9"
DEFAULT_WIDTH, DEFAULT_HEIGHT = SUPPORTED_RESOLUTIONS[DEFAULT_RESOLUTION]


def _warn_if_unsupported(width: int, height: int) -> None:
    """Print a warning when ``(width, height)`` is not one of ``SUPPORTED_RESOLUTIONS``."""
    if (width, height) in SUPPORTED_RESOLUTIONS.values():
        return
    buckets = ", ".join(f"{r}->{w}x{h}" for r, (w, h) in SUPPORTED_RESOLUTIONS.items())
    print(
        f"[warn] ({width}x{height}) is outside the trained resolution set; "
        f"quality may degrade. Supported buckets: {buckets}"
    )


# Interleave inference requires a system prompt that describes the
# think / no-think protocol expected by the model during training.
# NOTE(review): the phrases "start with a block", "tags like , ." and
# "between and ," read as if angle-bracket tags were stripped from this text
# at some point; the literal is kept exactly as found -- confirm against the
# prompt used in training before relying on it.
DEFAULT_SYSTEM_MESSAGE = """You are a multimodal assistant capable of reasoning with both text and images. You support two modes:\n\nThink Mode: When reasoning is needed, you MUST start with a block and place all reasoning inside it. You MUST interleave text with generated images using tags like , . Images can ONLY be generated between and , and may be referenced in the final answer.\n\nNon-Think Mode: When no reasoning is needed, directly provide the answer without reasoning. Do not use tags like , ; present any images naturally alongside the text.\n\nAfter the think block, always provide a concise, user-facing final answer. The answer may include text, images, or both. Match the user's language in both reasoning and the final answer."""


def _set_seed(seed: int) -> None:
    """Make sampling reproducible across python / numpy / torch (+ every available accelerator backend)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    seed_all_accelerators(seed)


def _round_by(n: int, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``n``."""
    return round(n / factor) * factor


def _ceil_by(n: int, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``n``."""
    return math.ceil(n / factor) * factor


def _floor_by(n: int, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``n``."""
    return math.floor(n / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = 32,
    min_pixels: int = 512 * 512,
    max_pixels: int = (4 * 2048 * 2048) // 8,
) -> tuple[int, int]:
    """Return ``(h, w)`` that are divisible by ``factor``, keep aspect ratio,
    and fall inside ``[min_pixels, max_pixels]``.

    Adapted from the Qwen2.5-VL utility used by the training pipeline so
    generated-image sizes stay in the buckets the model saw during SFT.

    Raises:
        ValueError: if either dimension is not positive, or if the aspect
            ratio (long side / short side) exceeds 200.
    """
    # Reject degenerate sizes up front: they would otherwise surface as a
    # ZeroDivisionError further down instead of a meaningful error.
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {height}x{width}")

    aspect = max(height, width) / min(height, width)
    if aspect > 200:
        raise ValueError(f"absolute aspect ratio must be < 200, got {aspect}")

    h_bar = max(factor, _round_by(height, factor))
    w_bar = max(factor, _round_by(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same ratio, rounding down.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, _floor_by(height / beta, factor))
        w_bar = max(factor, _floor_by(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Too small: grow both sides by the same ratio, rounding up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = _ceil_by(height * beta, factor)
        w_bar = _ceil_by(width * beta, factor)
    return h_bar, w_bar


def _denorm(x: torch.Tensor) -> torch.Tensor:
    """Undo the ``NORM_MEAN`` / ``NORM_STD`` normalization and clamp to ``[0, 1]``.

    The statistics are broadcast as ``[1, 3, 1, 1]``, so the channel axis of
    ``x`` must be dimension 1.
    """
    mean = torch.tensor(NORM_MEAN, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    std = torch.tensor(NORM_STD, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    return (x * std + mean).clamp(0, 1)


def _to_pil(batch: torch.Tensor) -> Image.Image:
    """Convert a single [1, 3, H, W] normalized tensor to a PIL image."""
    arr = _denorm(batch.float()).permute(0, 2, 3, 1).cpu().numpy()
    arr = (arr * 255.0).round().astype(np.uint8)
    # Only the first batch element is converted.
    return Image.fromarray(arr[0])


class SenseNovaU1Interleave:
    """Thin wrapper around ``AutoModel.from_pretrained`` for interleaved
    text+image generation.

    Because ``sensenova_u1`` has already registered the config / model with
    transformers at import time, no ``trust_remote_code=True`` is needed.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        gguf_checkpoint: str | None = None,
        device_map: str | None = None,
        max_memory: str | None = None,
        vram_mode: str = DEFAULT_VRAM_MODE,
    ) -> None:
        """Load the model and tokenizer via ``load_model_and_tokenizer``.

        ``vram_mode`` is translated to a prefetch count with
        ``vram_mode_to_prefetch_count``; a positive count loads the model
        with ``for_offload=True``. All remaining arguments are forwarded to
        the loader unchanged.
        """
        self.device = device
        self.vram_mode = vram_mode
        self.prefetch_count = vram_mode_to_prefetch_count(vram_mode)
        self.model, self.tokenizer = load_model_and_tokenizer(
            model_path,
            dtype=dtype,
            device=device,
            gguf_checkpoint=gguf_checkpoint,
            for_offload=self.prefetch_count > 0,
            device_map=device_map,
            max_memory=max_memory,
        )

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        input_images: Sequence[Image.Image] = (),
        image_size: tuple[int, int] = (DEFAULT_WIDTH, DEFAULT_HEIGHT),
        cfg_scale: float = 4.0,
        img_cfg_scale: float = 1.0,
        timestep_shift: float = 3.0,
        cfg_interval: tuple[float, float] = (0.0, 1.0),
        num_steps: int = 50,
        think_mode: bool = True,
        system_message: str = DEFAULT_SYSTEM_MESSAGE,
        seed: int = 0,
    ) -> tuple[str, list[Image.Image]]:
        """Run one interleaved generation and return ``(text, images)``.

        Every argument is forwarded as-is to the model's ``interleave_gen``
        (with ``verbose=True``), executed inside the context returned by
        ``make_offload_ctx``. Each tensor returned by the model is converted
        to a PIL image with ``_to_pil``.
        """
        with make_offload_ctx(self.model, self.prefetch_count, self.device) as offloaded:
            text, image_tensors = offloaded.interleave_gen(
                self.tokenizer,
                prompt,
                images=list(input_images),
                image_size=image_size,
                cfg_scale=cfg_scale,
                img_cfg_scale=img_cfg_scale,
                timestep_shift=timestep_shift,
                cfg_interval=cfg_interval,
                num_steps=num_steps,
                system_message=system_message,
                think_mode=think_mode,
                seed=seed,
                verbose=True,
            )
        return text, [_to_pil(img) for img in image_tensors]


def _load_input_images(paths: Sequence[str], image_root: str = "") -> list[Image.Image]:
    """Load images from ``paths`` as RGB.

    When ``image_root`` is set, it is prepended to any non-absolute path;
    absolute paths are used as-is.

    Raises:
        FileNotFoundError: if a resolved path does not exist.
    """
    images: list[Image.Image] = []
    for p in paths:
        resolved = p if not image_root or Path(p).is_absolute() else str(Path(image_root) / p)
        if not Path(resolved).exists():
            raise FileNotFoundError(f"input image not found: {resolved}")
        # ``convert`` returns an independent, fully loaded image, so the file
        # handle can be released as soon as the ``with`` block exits.
        with Image.open(resolved) as src:
            images.append(src.convert("RGB"))
    return images


def _resolve_image_size(
    input_images: Sequence[Image.Image],
    fallback_w: int,
    fallback_h: int,
) -> tuple[int, int]:
    """Pick generation (W, H).

    With input images, follow the first one so edits stay aligned (snapped to
    32-aligned buckets via ``smart_resize``). Without input images, use the
    caller-provided fallback as-is -- it is expected to already be one of
    ``SUPPORTED_RESOLUTIONS``.
    """
    if input_images:
        w, h = input_images[0].size
        # smart_resize takes and returns (height, width); flip back to (W, H).
        resized_h, resized_w = smart_resize(h, w)
        return resized_w, resized_h
    return fallback_w, fallback_h


def _save_outputs(
    text: str,
    images: Sequence[Image.Image],
    out_dir: Path,
    stem: str,
    input_images: Sequence[Image.Image] = (),
    prompt: str = "",
) -> list[str]:
    """Persist the prompt + model output + generated images and (optionally)
    the user-supplied input images so a result can be reproduced from disk
    alone.

    Files written under ``out_dir`` (created if missing): ``{stem}.txt``,
    ``{stem}_input_{i}.png`` and ``{stem}_image_{i}.png``.

    Returns the relative filenames of saved input images.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    text_path = out_dir / f"{stem}.txt"
    if prompt:
        text_path.write_text(
            f"# PROMPT\n{prompt}\n\n# OUTPUT\n{text}\n",
            encoding="utf-8",
        )
    else:
        text_path.write_text(text, encoding="utf-8")
    print(f"[saved] {text_path}")

    input_names: list[str] = []
    for i, img in enumerate(input_images):
        name = f"{stem}_input_{i}.png"
        img_path = out_dir / name
        img.save(img_path)
        input_names.append(name)
        print(f"[saved] {img_path}")

    for i, img in enumerate(images):
        img_path = out_dir / f"{stem}_image_{i}.png"
        img.save(img_path)
        print(f"[saved] {img_path}")

    return input_names


def _sample_images(sample: dict, image_root: str = "") -> list[Image.Image]:
    """Load ``sample['image']`` (or ``sample['images']``).

    Relative paths are resolved against ``image_root`` when provided. A
    missing key is treated as no input images. A single path given as a bare
    string is accepted and treated as a one-element list.
    """
    paths = sample.get("image") or sample.get("images") or []
    if isinstance(paths, str):
        # Without this, the string would be iterated character by character.
        paths = [paths]
    return _load_input_images(paths, image_root=image_root)


def _extract_prompt(sample: dict) -> str:
    """Accept both flat ``{"prompt": ...}`` and ShareGPT-style
    ``{"conversations": [{"from": "human", "value": ...}, ...]}``.

    For the ShareGPT form, the first turn whose ``from`` is ``"human"`` wins.

    Raises:
        ValueError: if neither form yields a prompt.
    """
    if "prompt" in sample:
        return sample["prompt"]
    for conv in sample.get("conversations", []):
        if conv.get("from") == "human":
            return conv["value"]
    raise ValueError("sample has no 'prompt' and no human turn in 'conversations'")


def _read_jsonl(path: str) -> list[dict]:
    """Parse a UTF-8 JSONL file, skipping blank lines.

    Raises:
        ValueError: on a malformed line, with the file name and line number.
    """
    samples: list[dict] = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                samples.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno}: invalid JSON: {err}") from err
    return samples


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    p = argparse.ArgumentParser(description="Interleaved text+image inference for SenseNova-U1.")
    p.add_argument(
        "--model_path",
        required=True,
        help="HuggingFace Hub id (e.g. sensenova/SenseNova-U1-8B-MoT) or a local path.",
    )
    p.add_argument(
        "--lora_path",
        required=False,
        default=None,
        help="HuggingFace Hub id or a local path to a lora model.",
    )

    # Exactly one input source must be chosen.
    src = p.add_mutually_exclusive_group(required=True)
    src.add_argument("--prompt", help="Generate from a single text prompt.")
    src.add_argument(
        "--jsonl",
        help=(
            'JSONL file, one sample per line. Required: {"prompt": ...} or '
            '{"conversations": [{"from": "human", "value": ...}, ...]}. '
            'Optional: {"image": [paths], "width": W, "height": H, "seed": S, '
            '"think_mode": bool}. '
            "If 'image' is set, output size follows the first input image "
            "(via smart_resize); 'width'/'height' are used only for "
            "text-only samples."
        ),
    )

    p.add_argument(
        "--image",
        action="append",
        default=[],
        # NOTE(review): the empty '' below looks like a placeholder token that
        # lost its angle-bracket markup; kept verbatim -- confirm the token.
        help=(
            "Path to an input image (repeatable). Only valid with --prompt. "
            "The prompt should contain a matching '' placeholder per image."
        ),
    )
    p.add_argument(
        "--image_root",
        default="",
        help=(
            "Directory prepended to relative image paths in --jsonl samples. "
            "Absolute paths (in the jsonl or --image) are used as-is."
        ),
    )
    p.add_argument("--output_dir", default="outputs", help="Directory for generated text + images.")
    p.add_argument(
        "--stem",
        default="sample",
        # File names mirror what ``_save_outputs`` writes.
        help="Filename stem when using --prompt. Generated files are <stem>.txt and <stem>_image_<i>.png.",
    )
    p.add_argument(
        "--resolution",
        default=DEFAULT_RESOLUTION,
        choices=list(SUPPORTED_RESOLUTIONS.keys()),
        help=(
            f"Aspect-ratio bucket used when no input image is provided "
            f"(default: {DEFAULT_RESOLUTION} -> {DEFAULT_WIDTH}x{DEFAULT_HEIGHT}). "
            "Overridden by --width/--height when both are set."
        ),
    )
    p.add_argument(
        "--width",
        type=int,
        default=None,
        help="Explicit fallback width. Overrides --resolution when both --width and --height are set.",
    )
    p.add_argument(
        "--height",
        type=int,
        default=None,
        help="Explicit fallback height. Overrides --resolution when both --width and --height are set.",
    )
    p.add_argument("--cfg_scale", type=float, default=4.0)
    p.add_argument("--img_cfg_scale", type=float, default=1.0)
    p.add_argument("--timestep_shift", type=float, default=3.0)
    p.add_argument(
        "--cfg_interval",
        type=float,
        nargs=2,
        default=[0.0, 1.0],
        metavar=("LO", "HI"),
    )
    p.add_argument("--num_steps", type=int, default=50)
    p.add_argument(
        "--think_mode",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Enable reasoning before the final answer. On by default; pass --no-think_mode to disable.",
    )
    p.add_argument(
        "--system_message",
        default=DEFAULT_SYSTEM_MESSAGE,
        help="Override the default interleave system prompt.",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=DEFAULT_SEED,
        help=(
            f"Random seed for reproducible sampling (default: {DEFAULT_SEED}). "
            "In --jsonl mode, a per-sample `seed` field overrides this."
        ),
    )
    p.add_argument(
        "--device",
        default=str(best_available_device()),
        help="Compute device, e.g. 'cuda', 'cuda:0', 'xpu', 'xpu:0', 'cpu'. Defaults to the best available accelerator.",
    )
    p.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
    )
    # Registers the offload-related options that ``main`` reads back as
    # ``args.vram_mode``, ``args.device_map`` and ``args.max_memory``.
    add_offload_args(p)
    p.add_argument(
        "--gguf_checkpoint",
        default=None,
        help=(
            "Optional path to a .gguf quantized checkpoint. When set, the dequantizing "
            "diffusers GGUF Linear layer is used instead of safetensors weights. "
            "Requires the [gguf] extra (gguf>=0.10.0, diffusers>=0.30.0)."
        ),
    )
    p.add_argument(
        "--attn_backend",
        default="auto",
        choices=["auto", "flash", "sdpa"],
        help=(
            "Attention kernel used by the Qwen3 layers. 'auto' picks flash-attn when importable and falls back to SDPA."
        ),
    )
    p.add_argument(
        "--profile",
        action="store_true",
        help=(
            "Print timing and CUDA memory stats: model load time, average "
            "per-image generation time, peak GPU memory, and the same time "
            f"normalized per image token (patch size = {DEFAULT_IMAGE_PATCH_SIZE})."
        ),
    )
    return p.parse_args()


def main() -> None:
    """Entry point: load the model, then run single-prompt or JSONL batch inference."""
    args = parse_args()
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]

    sensenova_u1.set_attn_backend(args.attn_backend)
    print(f"[attn] backend={args.attn_backend!r} (effective={sensenova_u1.effective_attn_backend()!r})")

    profiler = InferenceProfiler(
        enabled=args.profile,
        device=args.device,
        config={
            "vram_mode": args.vram_mode,
            "attn_backend": sensenova_u1.effective_attn_backend(),
            "dtype": args.dtype,
            "gguf": args.gguf_checkpoint,
        },
    )

    with profiler.time_load():
        engine = SenseNovaU1Interleave(
            args.model_path,
            device=args.device,
            dtype=dtype,
            gguf_checkpoint=args.gguf_checkpoint,
            device_map=args.device_map,
            max_memory=args.max_memory,
            vram_mode=args.vram_mode,
        )
        if args.lora_path is not None:
            print(f"load lora {args.lora_path}")
            engine.model = load_and_merge_lora_weight_from_safetensors(engine.model, args.lora_path)

    out_dir = Path(args.output_dir)

    # Sampling options shared by every ``engine.generate`` call below.
    gen_kwargs = {
        "cfg_scale": args.cfg_scale,
        "img_cfg_scale": args.img_cfg_scale,
        "timestep_shift": args.timestep_shift,
        "cfg_interval": tuple(args.cfg_interval),
        "num_steps": args.num_steps,
        "system_message": args.system_message,
    }

    if args.width is not None and args.height is not None:
        fallback_w, fallback_h = args.width, args.height
        _warn_if_unsupported(fallback_w, fallback_h)
    else:
        if (args.width is None) != (args.height is None):
            # A lone --width or --height has no effect; say so instead of
            # silently dropping it.
            print("[warn] --width and --height must be set together; falling back to --resolution.")
        fallback_w, fallback_h = SUPPORTED_RESOLUTIONS[args.resolution]

    # Single-sample inference: --prompt + optional --image (repeatable).
    if args.prompt is not None:
        print("prompt:", args.prompt)
        input_images = _load_input_images(args.image)
        w, h = _resolve_image_size(input_images, fallback_w, fallback_h)
        with profiler.time_generate(w, h, 1):
            text, images = engine.generate(
                args.prompt,
                input_images=input_images,
                image_size=(w, h),
                think_mode=args.think_mode,
                seed=args.seed,
                **gen_kwargs,
            )
            profiler.update_last_batch(len(images))
        print(f"[text] {text}")
        _save_outputs(
            text,
            images,
            out_dir,
            args.stem,
            input_images=input_images,
            prompt=args.prompt,
        )
        profiler.report()
        return

    # Batch inference: one sample per line in --jsonl.
    if args.image:
        print("[warn] --image is only used with --prompt; ignoring it in --jsonl mode.")
    samples = _read_jsonl(args.jsonl)

    try:
        from tqdm import tqdm
    except ImportError:
        # tqdm is optional; degrade to plain iteration without a progress bar.
        def tqdm(x, **_kw):  # type: ignore[no-redef]
            return x

    results_path = out_dir / "results.jsonl"
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(results_path, "w", encoding="utf-8") as rf:
        for i, sample in enumerate(tqdm(samples, desc="interleave")):
            prompt = _extract_prompt(sample)
            input_images = _sample_images(sample, image_root=args.image_root)
            if input_images:
                # When the sample ships input images, always follow their
                # size (via smart_resize); any per-sample width/height is
                # treated as a no-input-image fallback only.
                w, h = _resolve_image_size(input_images, fallback_w, fallback_h)
            elif "width" in sample and "height" in sample:
                w, h = int(sample["width"]), int(sample["height"])
                _warn_if_unsupported(w, h)
            else:
                w, h = fallback_w, fallback_h

            think_mode = bool(sample.get("think_mode", args.think_mode))
            # Honor the documented per-sample `seed` override; fall back to
            # the global --seed when the sample does not carry one.
            seed = int(sample.get("seed", args.seed))

            with profiler.time_generate(w, h, 1):
                text, images = engine.generate(
                    prompt,
                    input_images=input_images,
                    image_size=(w, h),
                    think_mode=think_mode,
                    seed=seed,
                    **gen_kwargs,
                )
                profiler.update_last_batch(len(images))

            stem = f"{i + 1:04d}" + ("_think" if think_mode else "_no_think")
            input_names = _save_outputs(
                text,
                images,
                out_dir,
                stem,
                input_images=input_images,
                prompt=prompt,
            )
            rf.write(
                json.dumps(
                    {
                        "index": i,
                        "prompt": prompt,
                        "text": text,
                        "input_images": input_names,
                        "images": [f"{stem}_image_{j}.png" for j in range(len(images))],
                        "width": w,
                        "height": h,
                        "think_mode": think_mode,
                        "seed": seed,
                    },
                    ensure_ascii=False,
                )
                + "\n"
            )
            # Flush per sample so partial results survive an interrupted run.
            rf.flush()
    print(f"[saved] {results_path}")
    profiler.report()


if __name__ == "__main__":
    main()