Commit c7bfdb1c authored by Watebear, committed by GitHub

feature: support qwen-image t2i (#217)



* feature: support qwen-image t2i

* pass ci

---------
Co-authored-by: wushuo1 <wushuo1@sensetime.com>
parent 3c3aa562
{
"seed": 42,
"batchsize": 1,
"_comment": "格式: '宽高比': [width, height]",
"aspect_ratios": {
"1:1": [1328, 1328],
"16:9": [1664, 928],
"9:16": [928, 1664],
"4:3": [1472, 1140],
"3:4": [142, 184]
},
"aspect_ratio": "16:9",
"num_channels_latents": 16,
"vae_scale_factor": 8,
"infer_steps": 50,
"guidance_embeds": false,
"num_images_per_prompt": 1,
"vae_latents_mean": [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921
],
"vae_latents_std": [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.916
],
"vae_z_dim": 16,
"feature_caching": "NoCaching"
}
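For orientation, the sketch below (illustrative, not part of the commit) shows how this config resolves into latent and packed-sequence shapes for the default "16:9" ratio; the same arithmetic reappears later in the runner's set_target_shape() and the scheduler's _pack_latents().

# Illustrative shape arithmetic for the "16:9" entry above.
vae_scale_factor = 8
width, height = 1664, 928                        # aspect_ratios["16:9"]

# 8x VAE compression, rounded so latent H/W stay divisible by 2 (2x2 packing).
lat_h = 2 * (height // (vae_scale_factor * 2))   # 116
lat_w = 2 * (width // (vae_scale_factor * 2))    # 208

num_channels_latents = 16                        # transformer in_channels // 4
target_shape = (1, 1, num_channels_latents, lat_h, lat_w)

# After 2x2 patch packing the transformer sees (lat_h // 2) * (lat_w // 2)
# tokens, each with num_channels_latents * 4 = 64 channels.
packed_seq_len = (lat_h // 2) * (lat_w // 2)
print(target_shape, packed_seq_len)              # (1, 1, 16, 116, 208) 6032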
...@@ -7,6 +7,7 @@ from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner # noqa: F401
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner # noqa: F401
+from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
...@@ -51,11 +52,12 @@ def main():
             "wan2.2_moe_audio",
             "wan2.2_audio",
             "wan2.2_moe_distill",
+            "qwen_image",
         ],
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
     parser.add_argument("--use_prompt_enhancer", action="store_true")
...
import os
import torch
from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
class Qwen25_VLForConditionalGeneration_TextEncoder:
def __init__(self, config):
self.config = config
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(config.model_path, "text_encoder")).to(torch.device("cuda")).to(torch.bfloat16)
self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(config.model_path, "tokenizer"))
self.tokenizer_max_length = 1024
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode_start_idx = 34
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
def infer(self, text):
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in text]
txt_tokens = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
prompt_embeds_mask = encoder_attention_mask
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, self.config.num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config.num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len)
return prompt_embeds, prompt_embeds_mask
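A hypothetical usage sketch of the encoder above; the model path is a placeholder and the config stub only carries the attributes the class actually reads.

# Hypothetical usage sketch (placeholder path).
from types import SimpleNamespace

_cfg = SimpleNamespace(model_path="/path/to/Qwen-Image", batchsize=1, num_images_per_prompt=1)
encoder = Qwen25_VLForConditionalGeneration_TextEncoder(_cfg)

# infer() takes a list of prompts; the 34-token chat-template prefix is dropped
# before padding, so seq_len covers only the prompt tokens.
prompt_embeds, prompt_embeds_mask = encoder.infer(["a red bicycle leaning against a brick wall"])
print(prompt_embeds.shape, prompt_embeds_mask.shape)  # e.g. [1, seq_len, 3584] / [1, seq_len]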
class QwenImagePostInfer:
def __init__(self, config, norm_out, proj_out):
self.config = config
self.norm_out = norm_out
self.proj_out = proj_out
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, hidden_states, temb):
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
return output
from lightx2v.utils.envs import *
class QwenImagePreInfer:
def __init__(self, config, img_in, txt_norm, txt_in, time_text_embed, pos_embed):
self.config = config
self.img_in = img_in
self.txt_norm = txt_norm
self.txt_in = txt_in
self.time_text_embed = time_text_embed
self.pos_embed = pos_embed
self.attention_kwargs = {}
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, hidden_states, timestep, guidance, encoder_hidden_states_mask, encoder_hidden_states, img_shapes, txt_seq_lens, attention_kwargs):
hidden_states_0 = hidden_states
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_text_embed(timestep, hidden_states) if guidance is None else self.time_text_embed(timestep, guidance, hidden_states)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
return hidden_states, encoder_hidden_states, encoder_hidden_states_mask, (hidden_states_0, temb, image_rotary_emb)
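The tuple returned above is consumed in two places further down in this diff: the first two tensors feed the transformer blocks, while the trailing (hidden_states_0, temb, image_rotary_emb) triple is threaded through as pre_infer_out so the block loop and the final projection can reuse the time embedding and rotary tables. A minimal sketch of that hand-off, mirroring QwenImageTransformerModel.infer below (no new behaviour):

# Sketch of the three-stage hand-off (mirrors QwenImageTransformerModel.infer).
def single_step(pre_infer, transformer_infer, post_infer, latents, t,
                prompt_embeds, prompt_embeds_mask, img_shapes, txt_seq_lens):
    hs, ehs, ehs_mask, pre_out = pre_infer.infer(
        hidden_states=latents, timestep=t / 1000, guidance=None,
        encoder_hidden_states_mask=prompt_embeds_mask, encoder_hidden_states=prompt_embeds,
        img_shapes=img_shapes, txt_seq_lens=txt_seq_lens, attention_kwargs={},
    )
    _, temb, image_rotary_emb = pre_out      # reused by every block and by post_infer
    ehs, hs = transformer_infer.infer(hs, ehs, ehs_mask, pre_out, {})
    return post_infer.infer(hs, temb)        # noise prediction in packed-latent space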
import torch
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
class QwenImageTransformerInfer(BaseTransformerInfer):
def __init__(self, config, blocks):
self.config = config
self.blocks = blocks
self.infer_conditional = True
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer_block(self, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, joint_attention_kwargs):
# Get modulation parameters for both streams
img_mod_params = block.img_mod(temb) # [B, 6*dim]
txt_mod_params = block.txt_mod(temb) # [B, 6*dim]
# Split modulation parameters for norm1 and norm2
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
img_normed = block.img_norm1(hidden_states)
img_modulated, img_gate1 = block._modulate(img_normed, img_mod1)
# Process text stream - norm1 + modulation
txt_normed = block.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = block._modulate(txt_normed, txt_mod1)
# Use QwenAttnProcessor2_0 for joint attention computation
# This directly implements the DoubleStreamLayerMegatron logic:
# 1. Computes QKV for both streams
# 2. Applies QK normalization and RoPE
# 3. Concatenates and runs joint attention
# 4. Splits results back to separate streams
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = block.attn(
hidden_states=img_modulated, # Image stream (will be processed as "sample")
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
img_attn_output, txt_attn_output = attn_output
# Apply attention gates and add residual (like in Megatron)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
img_normed2 = block.img_norm2(hidden_states)
img_modulated2, img_gate2 = block._modulate(img_normed2, img_mod2)
img_mlp_output = block.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
txt_normed2 = block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = block._modulate(txt_normed2, txt_mod2)
txt_mlp_output = block.txt_mlp(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
# Clip to prevent overflow for fp16
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
def infer_calculating(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, attention_kwargs):
for index_block, block in enumerate(self.blocks):
encoder_hidden_states, hidden_states = self.infer_block(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
return encoder_hidden_states, hidden_states
def infer(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, pre_infer_out, attention_kwargs):
_, temb, image_rotary_emb = pre_infer_out
encoder_hidden_states, hidden_states = self.infer_calculating(hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, attention_kwargs)
return encoder_hidden_states, hidden_states
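infer_block above relies on each block's own _modulate helper to split a [B, 3*dim] modulation chunk into shift/scale/gate. For reference, a minimal stand-in with the behaviour assumed here; the actual helper ships with diffusers' QwenImageTransformerBlock and may differ in detail.

import torch

# Minimal stand-in for the block-level `_modulate` assumed by infer_block:
# AdaLN-style shift/scale applied to the normed stream, gate returned for the residual.
def _modulate(x: torch.Tensor, mod_params: torch.Tensor):
    shift, scale, gate = mod_params.chunk(3, dim=-1)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)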
from typing import Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class DefaultLinear(nn.Linear):
def forward(self, input: Tensor) -> Tensor:
return F.linear(input, self.weight, self.bias)
def replace_linear_with_custom(model: nn.Module, CustomLinear: Type[nn.Module]) -> nn.Module:
for name, module in model.named_children():
if isinstance(module, nn.Linear):
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
custom_linear = CustomLinear(in_features=in_features, out_features=out_features, bias=bias)
with torch.no_grad():
custom_linear.weight.copy_(module.weight)
if bias:
custom_linear.bias.copy_(module.bias)
setattr(model, name, custom_linear)
else:
replace_linear_with_custom(module, CustomLinear)
return model
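Usage sketch with a hypothetical subclass: replace_linear_with_custom walks the model recursively, so any nn.Linear-compatible class can be swapped in, for example to host a quantized matmul later.

# Hypothetical custom Linear swapped into a toy model.
import torch.nn as nn

class LoggingLinear(DefaultLinear):
    def forward(self, input):
        # identical math to DefaultLinear; a real custom layer would change this
        return super().forward(input)

m = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
m = replace_linear_with_custom(m, LoggingLinear)
print(type(m[0]).__name__)  # LoggingLinear, with weights/bias copied from the original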
from typing import Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
__all__ = ["LayerNorm", "RMSNorm"]
class DefaultLayerNorm(nn.LayerNorm):
def forward(self, input: Tensor) -> Tensor:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
class DefaultRMSNorm(nn.RMSNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
def replace_layernorm_with_custom(model: nn.Module, CustomLayerNorm: Type[nn.Module]) -> nn.Module:
for name, module in model.named_children():
if isinstance(module, nn.LayerNorm):
normalized_shape = module.normalized_shape
eps = module.eps
elementwise_affine = module.elementwise_affine
custom_layernorm = CustomLayerNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
with torch.no_grad():
custom_layernorm.weight.copy_(module.weight)
custom_layernorm.bias.copy_(module.bias)
setattr(model, name, custom_layernorm)
else:
replace_layernorm_with_custom(module, CustomLayerNorm)
return model
def replace_rmsnorm_with_custom(model: nn.Module, CustomRMSNorm: Type[nn.Module]) -> nn.Module:
for name, module in model.named_children():
if isinstance(module, nn.RMSNorm):
normalized_shape = module.normalized_shape
eps = getattr(module, "eps", 1e-6)
elementwise_affine = getattr(module, "elementwise_affine", True)
custom_rmsnorm = CustomRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
with torch.no_grad():
custom_rmsnorm.weight.copy_(module.weight)
if hasattr(module, "bias") and hasattr(custom_rmsnorm, "bias"):
custom_rmsnorm.bias.copy_(module.bias)
setattr(model, name, custom_rmsnorm)
else:
replace_rmsnorm_with_custom(module, CustomRMSNorm)
return model
import json
import os
import torch
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
from .infer.post_infer import QwenImagePostInfer
from .infer.pre_infer import QwenImagePreInfer
from .infer.transformer_infer import QwenImageTransformerInfer
from .layers.linear import DefaultLinear, replace_linear_with_custom
from .layers.normalization import DefaultLayerNorm, DefaultRMSNorm, replace_layernorm_with_custom, replace_rmsnorm_with_custom
class QwenImageTransformerModel:
def __init__(self, config):
self.config = config
self.transformer = QwenImageTransformer2DModel.from_pretrained(os.path.join(config.model_path, "transformer"))
# replace linear & normalization layers with custom implementations
self.transformer = replace_linear_with_custom(self.transformer, DefaultLinear)
self.transformer = replace_layernorm_with_custom(self.transformer, DefaultLayerNorm)
self.transformer = replace_rmsnorm_with_custom(self.transformer, DefaultRMSNorm)
self.transformer.to(torch.device("cuda")).to(torch.bfloat16)
with open(os.path.join(config.model_path, "transformer", "config.json"), "r") as f:
transformer_config = json.load(f)
self.in_channels = transformer_config["in_channels"]
self.attention_kwargs = {}
self._init_infer_class()
self._init_infer()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def _init_infer_class(self):
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = QwenImageTransformerInfer
else:
raise NotImplementedError(f"Unsupported feature_caching: {self.config['feature_caching']}")
self.pre_infer_class = QwenImagePreInfer
self.post_infer_class = QwenImagePostInfer
def _init_infer(self):
self.transformer_infer = self.transformer_infer_class(self.config, self.transformer.transformer_blocks)
self.pre_infer = self.pre_infer_class(self.config, self.transformer.img_in, self.transformer.txt_norm, self.transformer.txt_in, self.transformer.time_text_embed, self.transformer.pos_embed)
self.post_infer = self.post_infer_class(self.config, self.transformer.norm_out, self.transformer.proj_out)
@torch.no_grad()
def infer(self, inputs):
t = self.scheduler.timesteps[self.scheduler.step_index]
latents = self.scheduler.latents
timestep = t.expand(latents.shape[0]).to(latents.dtype)
img_shapes = self.scheduler.img_shapes
prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"]
prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
hidden_states, encoder_hidden_states, encoder_hidden_states_mask, pre_infer_out = self.pre_infer.infer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=self.scheduler.guidance,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
)
encoder_hidden_states, hidden_states = self.transformer_infer.infer(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
pre_infer_out=pre_infer_out,
attention_kwargs=self.attention_kwargs,
)
noise_pred = self.post_infer.infer(hidden_states, pre_infer_out[1])
self.scheduler.noise_pred = noise_pred
import gc
import torch
from loguru import logger
from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(self, config):
super().__init__(config)
@ProfilingContext("Load models")
def load_model(self):
self.model = self.load_transformer()
self.text_encoders = self.load_text_encoder()
self.vae = self.load_vae()
def load_transformer(self):
model = QwenImageTransformerModel(self.config)
return model
def load_text_encoder(self):
text_encoder = Qwen25_VLForConditionalGeneration_TextEncoder(self.config)
text_encoders = [text_encoder]
return text_encoders
def load_image_encoder(self):
pass
def load_vae(self):
vae = AutoencoderKLQwenImageVAE(self.config)
return vae
def init_modules(self):
logger.info("Initializing runner modules...")
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.load_model()
elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False)
self.run_dit = self._run_dit_local
self.run_vae_decoder = self._run_vae_decoder_local
if self.config["task"] == "t2i":
self.run_input_encoder = self._run_input_encoder_local_i2v
else:
raise NotImplementedError(f"Unsupported task: {self.config['task']}")
@ProfilingContext("Run Encoders")
def _run_input_encoder_local_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt)
torch.cuda.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": None,
}
def run_text_encoder(self, text):
text_encoder_output = {}
prompt_embeds, prompt_embeds_mask = self.text_encoders[0].infer([text])
text_encoder_output["prompt_embeds"] = prompt_embeds
text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
return text_encoder_output
def set_target_shape(self):
self.vae_scale_factor = self.vae.vae_scale_factor if getattr(self, "vae", None) else 8
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
num_channels_latents = self.model.in_channels // 4
self.config.target_shape = (self.config.batchsize, 1, num_channels_latents, height, width)
def init_scheduler(self):
scheduler = QwenImageScheduler(self.config)
self.model.set_scheduler(scheduler)
self.model.pre_infer.set_scheduler(scheduler)
self.model.transformer_infer.set_scheduler(scheduler)
self.model.post_infer.set_scheduler(scheduler)
def get_encoder_output_i2v(self):
pass
def run_image_encoder(self):
pass
def run_vae_encoder(self):
pass
@ProfilingContext("Load models")
def load_model(self):
self.model = self.load_transformer()
self.text_encoders = self.load_text_encoder()
self.image_encoder = self.load_image_encoder()
self.vae = self.load_vae()
self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
@ProfilingContext("Run VAE Decoder")
def _run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae()
images = self.vae.decode(latents)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
return images
def run_pipeline(self, save_image=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder()
self.set_target_shape()
latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator)
image = images[0]
image.save(f"{self.config.save_video_path}")
del latents, generator
torch.cuda.empty_cache()
gc.collect()
# Return (images, audio) - audio is None for default runner
return images, None
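run_dit itself is inherited from DefaultRunner and is not part of this diff; conceptually it iterates the scheduler's timesteps, letting the transformer write a noise prediction and the scheduler take an Euler step. A hedged sketch of that loop, using only the scheduler/model method names that appear in this commit (the actual DefaultRunner implementation may differ):

# Hedged sketch of the step loop run_dit is expected to drive.
def run_dit_sketch(runner):
    runner.init_scheduler()                        # wires the scheduler into all infer stages
    scheduler = runner.model.scheduler
    scheduler.prepare(image_encoder_output=None)   # latents, guidance, img_shapes, timesteps
    for step_index in range(scheduler.infer_steps):
        scheduler.step_index = step_index
        runner.model.infer(runner.inputs)          # writes scheduler.noise_pred
        scheduler.step_post()                      # Euler step: x_t -> x_{t-1}
    return scheduler.latents, scheduler.generator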
import inspect
import json
import os
from typing import List, Optional, Union
import numpy as np
import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
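As a worked example with the defaults above, the 16:9 config (1664 x 928) yields 208 x 116 latents, i.e. 104 * 58 = 6032 packed tokens, so the shift resolves to roughly 1.48:

# m  = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4
# b  = 0.5 - m * 256               ≈ 0.4567
# mu = 6032 * m + b                ≈ 1.48
print(round(calculate_shift(6032), 2))  # ~1.48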
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.")
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class QwenImageScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.config = config
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config.model_path, "scheduler"))
with open(os.path.join(config.model_path, "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler_config = json.load(f)
self.generator = torch.Generator(device="cuda").manual_seed(config.seed)
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.guidance_scale = 1.0
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents(self):
shape = self.config.target_shape
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
self.vae_scale_factor = self.config.vae_scale_factor if getattr(self, "vae", None) else 8
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype)
latents = self._pack_latents(latents, self.config.batchsize, self.config.num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(self.config.batchsize, height // 2, width // 2, self.device, self.dtype)
self.latents = latents
self.latent_image_ids = latent_image_ids
self.noise_pred = None
def set_timesteps(self):
sigmas = np.linspace(1.0, 1 / self.config.infer_steps, self.config.infer_steps)
image_seq_len = self.latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler_config.get("base_image_seq_len", 256),
self.scheduler_config.get("max_image_seq_len", 4096),
self.scheduler_config.get("base_shift", 0.5),
self.scheduler_config.get("max_shift", 1.15),
)
num_inference_steps = self.config.infer_steps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
self.device,
sigmas=sigmas,
mu=mu,
)
self.timesteps = timesteps
self.infer_steps = num_inference_steps
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
self.num_warmup_steps = num_warmup_steps
def prepare_guidance(self):
# handle guidance
if self.config.guidance_embeds:
guidance = torch.full([1], self.guidance_scale, device=self.device, dtype=torch.float32)
guidance = guidance.expand(self.latents.shape[0])
else:
guidance = None
self.guidance = guidance
def set_img_shapes(self):
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
self.img_shapes = [(1, height // self.config.vae_scale_factor // 2, width // self.config.vae_scale_factor // 2)] * self.config.batchsize
def prepare(self, image_encoder_output):
self.prepare_latents()
self.prepare_guidance()
self.set_img_shapes()
self.set_timesteps()
def step_post(self):
# compute the previous noisy sample x_t -> x_t-1
t = self.timesteps[self.step_index]
latents = self.scheduler.step(self.noise_pred, t, self.latents, return_dict=False)[0]
self.latents = latents
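Because the 2x2 packing appears in the runner, the scheduler, and the VAE, a quick round-trip check is useful; a minimal sketch (illustrative, not part of the commit) against the 16:9 default, where 16-channel 116 x 208 latents pack into 6032 tokens of 64 channels:

import torch

lat = torch.randn(1, 16, 116, 208)
packed = QwenImageScheduler._pack_latents(lat, 1, 16, 116, 208)       # [1, 6032, 64]
unpacked = QwenImageScheduler._unpack_latents(packed, 928, 1664, 8)   # [1, 16, 1, 116, 208]
assert torch.equal(unpacked.squeeze(2), lat)                           # exact inverse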
import json
import os
import torch # type: ignore
from diffusers import AutoencoderKLQwenImage
from diffusers.image_processor import VaeImageProcessor
class AutoencoderKLQwenImageVAE:
def __init__(self, config):
self.config = config
self.model = AutoencoderKLQwenImage.from_pretrained(os.path.join(config.model_path, "vae")).to(torch.device("cuda")).to(torch.bfloat16)
self.image_processor = VaeImageProcessor(vae_scale_factor=config.vae_scale_factor * 2)
with open(os.path.join(config.model_path, "vae", "config.json"), "r") as f:
vae_config = json.load(f)
self.vae_scale_factor = 2 ** len(vae_config["temperal_downsample"]) if "temperal_downsample" in vae_config else 8
self.dtype = torch.bfloat16
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
@torch.no_grad()
def decode(self, latents):
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
latents = self._unpack_latents(latents, height, width, self.config.vae_scale_factor)
latents = latents.to(self.dtype)
latents_mean = torch.tensor(self.config.vae_latents_mean).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.config.vae_latents_std).view(1, self.config.vae_z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
images = self.model.decode(latents, return_dict=False)[0][:, :, 0]
images = self.image_processor.postprocess(images, output_type="pil")
return images
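An illustrative shape trace through decode() for the 16:9 default (assumed shapes; the frame axis comes from the video-style VAE backbone):

# latents in:    [1, 6032, 64]         packed tokens from the scheduler
# after unpack:  [1, 16, 1, 116, 208]  z_dim x frame x latent H x W
# de-normalise:  latents / latents_std + latents_mean, i.e. x * std + mean,
#                since latents_std above is stored as a reciprocal
# after decode:  roughly [1, 3, 1, 928, 1664]; the [:, :, 0] index drops the frame axis
# postprocess:   a list with one PIL.Image at 1664 x 928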
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set lightx2v_path and model_path first
export lightx2v_path=
export model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls qwen_image \
--task t2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
--prompt 'A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition.' \
--save_video_path ${lightx2v_path}/save_results/qwen_image_t2i.png