Commit ecb2107c authored by gushiqiao, committed by GitHub

Merge branch 'main' into dev_flf2v

parents d8d70a28 3d8cb02e
{
"seed": 42,
"batchsize": 1,
"_comment": "格式: '宽高比': [width, height]",
"aspect_ratios": {
"1:1": [1328, 1328],
"16:9": [1664, 928],
"9:16": [928, 1664],
"4:3": [1472, 1140],
"3:4": [142, 184]
},
"aspect_ratio": "16:9",
"num_channels_latents": 16,
"vae_scale_factor": 8,
"infer_steps": 50,
"guidance_embeds": false,
"num_images_per_prompt": 1,
"vae_latents_mean": [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921
],
"vae_latents_std": [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.916
],
"vae_z_dim": 16,
"feature_caching": "NoCaching"
}
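For reference, a minimal sketch (plain Python, values taken from this config) of how the aspect-ratio entries, "vae_scale_factor", and "num_channels_latents" combine into the latent target shape and packed token count used by the runner and scheduler later in this commit; the 2x2 packing is the step implemented in QwenImageScheduler._pack_latents.

width, height = 1664, 928                        # the "16:9" entry above
vae_scale_factor = 8
num_channels_latents = 16

# Latent grid, rounded so height/width stay divisible by 2 for packing.
lat_h = 2 * (height // (vae_scale_factor * 2))   # 116
lat_w = 2 * (width // (vae_scale_factor * 2))    # 208

target_shape = (1, 1, num_channels_latents, lat_h, lat_w)    # (batch, frames, C, H, W)
packed_seq_len = (lat_h // 2) * (lat_w // 2)                 # 58 * 104 = 6032 image tokens
packed_channels = num_channels_latents * 4                   # 64, consistent with in_channels // 4 in set_target_shape()
print(target_shape, packed_seq_len, packed_channels)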
@@ -7,6 +7,7 @@ from lightx2v.common.ops import *
from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner # noqa: F401
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner # noqa: F401
+from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner # noqa: F401
from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, Wan22MoeAudioRunner, WanAudioRunner # noqa: F401
from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # noqa: F401
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
@@ -51,11 +52,12 @@ def main():
"wan2.2_moe_audio",
"wan2.2_audio",
"wan2.2_moe_distill",
+"qwen_image",
],
default="wan2.1",
)
-parser.add_argument("--task", type=str, choices=["t2v", "i2v", "flf2v"], default="t2v")
+parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "flf2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--use_prompt_enhancer", action="store_true")
...
import os
import torch
from transformers import Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration
class Qwen25_VLForConditionalGeneration_TextEncoder:
def __init__(self, config):
self.config = config
self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(config.model_path, "text_encoder")).to(torch.device("cuda")).to(torch.bfloat16)
self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(config.model_path, "tokenizer"))
self.tokenizer_max_length = 1024
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode_start_idx = 34
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
bool_mask = mask.bool()
valid_lengths = bool_mask.sum(dim=1)
selected = hidden_states[bool_mask]
split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
return split_result
def infer(self, text):
template = self.prompt_template_encode
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in text]
txt_tokens = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
encoder_hidden_states = self.text_encoder(
input_ids=txt_tokens.input_ids,
attention_mask=txt_tokens.attention_mask,
output_hidden_states=True,
)
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
prompt_embeds_mask = encoder_attention_mask
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, self.config.num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, self.config.num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(self.config.batchsize * self.config.num_images_per_prompt, seq_len)
return prompt_embeds, prompt_embeds_mask
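A small CPU-only sketch (dummy tensors, no model weights; toy sizes are assumptions for illustration) of the masking and padding logic in infer() above: tokens with mask 0 are dropped, the first drop_idx template tokens are stripped, and the remaining sequences are right-padded to a common length.

import torch

hidden = torch.randn(2, 6, 4)                          # [batch, seq, dim]
mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]])
valid = mask.bool().sum(dim=1)                         # tensor([4, 6])
per_sample = torch.split(hidden[mask.bool()], valid.tolist(), dim=0)
drop_idx = 2                                           # stand-in for the 34-token template prefix
per_sample = [e[drop_idx:] for e in per_sample]        # lengths 2 and 4
max_len = max(e.size(0) for e in per_sample)
padded = torch.stack([torch.cat([e, e.new_zeros(max_len - e.size(0), e.size(1))]) for e in per_sample])
print(padded.shape)                                    # torch.Size([2, 4, 4])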
class QwenImagePostInfer:
def __init__(self, config, norm_out, proj_out):
self.config = config
self.norm_out = norm_out
self.proj_out = proj_out
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, hidden_states, temb):
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
return output
from lightx2v.utils.envs import *
class QwenImagePreInfer:
def __init__(self, config, img_in, txt_norm, txt_in, time_text_embed, pos_embed):
self.config = config
self.img_in = img_in
self.txt_norm = txt_norm
self.txt_in = txt_in
self.time_text_embed = time_text_embed
self.pos_embed = pos_embed
self.attention_kwargs = {}
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, hidden_states, timestep, guidance, encoder_hidden_states_mask, encoder_hidden_states, img_shapes, txt_seq_lens, attention_kwargs):
hidden_states_0 = hidden_states
hidden_states = self.img_in(hidden_states)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
temb = self.time_text_embed(timestep, hidden_states) if guidance is None else self.time_text_embed(timestep, guidance, hidden_states)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
return hidden_states, encoder_hidden_states, encoder_hidden_states_mask, (hidden_states_0, temb, image_rotary_emb)
import torch
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
class QwenImageTransformerInfer(BaseTransformerInfer):
def __init__(self, config, blocks):
self.config = config
self.blocks = blocks
self.infer_conditional = True
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer_block(self, block, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, joint_attention_kwargs):
# Get modulation parameters for both streams
img_mod_params = block.img_mod(temb) # [B, 6*dim]
txt_mod_params = block.txt_mod(temb) # [B, 6*dim]
# Split modulation parameters for norm1 and norm2
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
img_normed = block.img_norm1(hidden_states)
img_modulated, img_gate1 = block._modulate(img_normed, img_mod1)
# Process text stream - norm1 + modulation
txt_normed = block.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = block._modulate(txt_normed, txt_mod1)
# Use QwenAttnProcessor2_0 for joint attention computation
# This directly implements the DoubleStreamLayerMegatron logic:
# 1. Computes QKV for both streams
# 2. Applies QK normalization and RoPE
# 3. Concatenates and runs joint attention
# 4. Splits results back to separate streams
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = block.attn(
hidden_states=img_modulated, # Image stream (will be processed as "sample")
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
img_attn_output, txt_attn_output = attn_output
# Apply attention gates and add residual (like in Megatron)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
# Process image stream - norm2 + MLP
img_normed2 = block.img_norm2(hidden_states)
img_modulated2, img_gate2 = block._modulate(img_normed2, img_mod2)
img_mlp_output = block.img_mlp(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# Process text stream - norm2 + MLP
txt_normed2 = block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = block._modulate(txt_normed2, txt_mod2)
txt_mlp_output = block.txt_mlp(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
# Clip to prevent overflow for fp16
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
def infer_calculating(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, attention_kwargs):
for index_block, block in enumerate(self.blocks):
encoder_hidden_states, hidden_states = self.infer_block(
block=block,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
return encoder_hidden_states, hidden_states
def infer(self, hidden_states, encoder_hidden_states, encoder_hidden_states_mask, pre_infer_out, attention_kwargs):
_, temb, image_rotary_emb = pre_infer_out
encoder_hidden_states, hidden_states = self.infer_calculating(hidden_states, encoder_hidden_states, encoder_hidden_states_mask, temb, image_rotary_emb, attention_kwargs)
return encoder_hidden_states, hidden_states
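block._modulate and block.attn come from diffusers' QwenImage transformer block, which is not shown in this commit. As an assumption for illustration only (not copied from that class), a common AdaLN-style _modulate splits each modulation vector into shift/scale/gate; a minimal sketch:

import torch

def modulate(x, mod_params):
    # Assumed AdaLN-style split: [B, 3*dim] -> shift, scale, gate, each [B, dim].
    shift, scale, gate = mod_params.chunk(3, dim=-1)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

x = torch.randn(2, 10, 8)        # [B, seq, dim], e.g. img_norm1(hidden_states)
mod = torch.randn(2, 24)         # one half of img_mod(temb), i.e. img_mod1
x_mod, gate = modulate(x, mod)
print(x_mod.shape, gate.shape)   # torch.Size([2, 10, 8]) torch.Size([2, 1, 8])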
from typing import Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class DefaultLinear(nn.Linear):
def forward(self, input: Tensor) -> Tensor:
return F.linear(input, self.weight, self.bias)
def replace_linear_with_custom(model: nn.Module, CustomLinear: Type[nn.Module]) -> nn.Module:
for name, module in model.named_children():
if isinstance(module, nn.Linear):
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
custom_linear = CustomLinear(in_features=in_features, out_features=out_features, bias=bias)
with torch.no_grad():
custom_linear.weight.copy_(module.weight)
if bias:
custom_linear.bias.copy_(module.bias)
setattr(model, name, custom_linear)
else:
replace_linear_with_custom(module, CustomLinear)
return model
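Usage sketch for the helper above (with DefaultLinear and replace_linear_with_custom in scope): every nn.Linear in a toy module is swapped out and the outputs are checked to be unchanged, since DefaultLinear only re-dispatches to F.linear.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
x = torch.randn(2, 8)
ref = model(x)
model = replace_linear_with_custom(model, DefaultLinear)
assert all(isinstance(m, DefaultLinear) for m in model if isinstance(m, nn.Linear))
torch.testing.assert_close(model(x), ref)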
from typing import Type
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
__all__ = ["LayerNorm", "RMSNorm"]
class DefaultLayerNorm(nn.LayerNorm):
def forward(self, input: Tensor) -> Tensor:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
class DefaultRMSNorm(nn.RMSNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
def replace_layernorm_with_custom(model: nn.Module, CustomLayerNorm: Type[nn.Module]) -> nn.Module:
for name, module in model.named_children():
if isinstance(module, nn.LayerNorm):
normalized_shape = module.normalized_shape
eps = module.eps
elementwise_affine = module.elementwise_affine
custom_layernorm = CustomLayerNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
with torch.no_grad():
custom_layernorm.weight.copy_(module.weight)
custom_layernorm.bias.copy_(module.bias)
setattr(model, name, custom_layernorm)
else:
replace_layernorm_with_custom(module, CustomLayerNorm)
return model
def replace_rmsnorm_with_custom(model: nn.Module, CustomRMSNorm: Type[nn.Module]) -> nn.Module:
for name, module in model.named_children():
if isinstance(module, nn.RMSNorm):
normalized_shape = module.normalized_shape
eps = getattr(module, "eps", 1e-6)
elementwise_affine = getattr(module, "elementwise_affine", True)
custom_rmsnorm = CustomRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
if elementwise_affine:
with torch.no_grad():
custom_rmsnorm.weight.copy_(module.weight)
if hasattr(module, "bias") and hasattr(custom_rmsnorm, "bias"):
custom_rmsnorm.bias.copy_(module.bias)
setattr(model, name, custom_rmsnorm)
else:
replace_rmsnorm_with_custom(module, CustomRMSNorm)
return model
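A similar sketch for the normalization helpers (assumes torch.nn.RMSNorm is available, i.e. PyTorch 2.4+), with the definitions above in scope:

import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(8, 8), nn.RMSNorm(8), nn.LayerNorm(8))
x = torch.randn(2, 8)
ref = block(x)
block = replace_rmsnorm_with_custom(block, DefaultRMSNorm)
block = replace_layernorm_with_custom(block, DefaultLayerNorm)
torch.testing.assert_close(block(x), ref)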
import json
import os
import torch
try:
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
except ImportError:
QwenImageTransformer2DModel = None
from .infer.post_infer import QwenImagePostInfer
from .infer.pre_infer import QwenImagePreInfer
from .infer.transformer_infer import QwenImageTransformerInfer
from .layers.linear import DefaultLinear, replace_linear_with_custom
from .layers.normalization import DefaultLayerNorm, DefaultRMSNorm, replace_layernorm_with_custom, replace_rmsnorm_with_custom
class QwenImageTransformerModel:
def __init__(self, config):
self.config = config
self.transformer = QwenImageTransformer2DModel.from_pretrained(os.path.join(config.model_path, "transformer"))
# replace linear & normalization
self.transformer = replace_linear_with_custom(self.transformer, DefaultLinear)
self.transformer = replace_layernorm_with_custom(self.transformer, DefaultLayerNorm)
self.transformer = replace_rmsnorm_with_custom(self.transformer, DefaultRMSNorm)
self.transformer.to(torch.device("cuda")).to(torch.bfloat16)
with open(os.path.join(config.model_path, "transformer", "config.json"), "r") as f:
transformer_config = json.load(f)
self.in_channels = transformer_config["in_channels"]
self.attention_kwargs = {}
self._init_infer_class()
self._init_infer()
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def _init_infer_class(self):
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = QwenImageTransformerInfer
else:
raise NotImplementedError(f"Unsupported feature_caching: {self.config['feature_caching']}")
self.pre_infer_class = QwenImagePreInfer
self.post_infer_class = QwenImagePostInfer
def _init_infer(self):
self.transformer_infer = self.transformer_infer_class(self.config, self.transformer.transformer_blocks)
self.pre_infer = self.pre_infer_class(self.config, self.transformer.img_in, self.transformer.txt_norm, self.transformer.txt_in, self.transformer.time_text_embed, self.transformer.pos_embed)
self.post_infer = self.post_infer_class(self.config, self.transformer.norm_out, self.transformer.proj_out)
@torch.no_grad()
def infer(self, inputs):
t = self.scheduler.timesteps[self.scheduler.step_index]
latents = self.scheduler.latents
timestep = t.expand(latents.shape[0]).to(latents.dtype)
img_shapes = self.scheduler.img_shapes
prompt_embeds = inputs["text_encoder_output"]["prompt_embeds"]
prompt_embeds_mask = inputs["text_encoder_output"]["prompt_embeds_mask"]
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
hidden_states, encoder_hidden_states, encoder_hidden_states_mask, pre_infer_out = self.pre_infer.infer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=self.scheduler.guidance,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
attention_kwargs=self.attention_kwargs,
)
encoder_hidden_states, hidden_states = self.transformer_infer.infer(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
pre_infer_out=pre_infer_out,
attention_kwargs=self.attention_kwargs,
)
noise_pred = self.post_infer.infer(hidden_states, pre_infer_out[1])
self.scheduler.noise_pred = noise_pred
@@ -54,7 +54,7 @@ class WanCausVidModel(WanModel):
self.pre_weight.to_cuda()
self.transformer_weights.post_weights_to_cuda()
-embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=True, kv_start=kv_start, kv_end=kv_end)
+embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, kv_start=kv_start, kv_end=kv_end)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
self.scheduler.noise_pred = self.post_infer.infer(x, embed, grid_sizes)[0]
...
@@ -35,7 +35,7 @@ class WanAudioPreInfer(WanPreInfer):
else:
self.sp_size = 1
-def infer(self, weights, inputs, positive):
+def infer(self, weights, inputs):
prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
if self.config.model_cls == "wan2.2_audio":
hidden_states = self.scheduler.latents
@@ -71,7 +71,7 @@ class WanAudioPreInfer(WanPreInfer):
audio_dit_blocks.append(inputs["audio_adapter_pipe"](**audio_model_input))
# audio_dit_blocks = None##Debug Drop Audio
-if positive:
+if self.scheduler.infer_condition:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
@@ -104,17 +104,34 @@ class WanAudioPreInfer(WanPreInfer):
y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
# y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
+ref_seq_lens = torch.tensor([u.size(0) for u in y], dtype=torch.long)
x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
x = torch.stack(x, dim=0)
+seq_len = x[0].size(0)
+if self.config.model_cls == "wan2.2_audio":
+bt = t.size(0)
+ref_seq_len = ref_seq_lens[0].item()
+t = torch.cat(
+[
+t,
+torch.zeros(
+(1, ref_seq_len),
+dtype=t.dtype,
+device=t.device,
+),
+],
+dim=1,
+)
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
+# embed = weights.time_embedding_0.apply(embed)
if self.sensitive_layer_dtype != self.infer_dtype:
embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
else:
embed = weights.time_embedding_0.apply(embed)
embed = torch.nn.functional.silu(embed)
embed = weights.time_embedding_2.apply(embed)
embed0 = torch.nn.functional.silu(embed)
embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
...
@@ -24,7 +24,6 @@ class WanTransformerInferCaching(WanTransformerInfer):
class WanTransformerInferTeaCaching(WanTransformerInferCaching):
def __init__(self, config):
super().__init__(config)
-self.cnt = 0
self.teacache_thresh = config.teacache_thresh
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = None
@@ -35,22 +34,23 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self.use_ret_steps = config.use_ret_steps
if self.use_ret_steps:
self.coefficients = self.config.coefficients[0]
-self.ret_steps = 5 * 2
-self.cutoff_steps = self.config.infer_steps * 2
+self.ret_steps = 5
+self.cutoff_steps = self.config.infer_steps
else:
self.coefficients = self.config.coefficients[1]
-self.ret_steps = 1 * 2
-self.cutoff_steps = self.config.infer_steps * 2 - 2
+self.ret_steps = 1
+self.cutoff_steps = self.config.infer_steps - 1
# calculate should_calc
+@torch.no_grad()
def calculate_should_calc(self, embed, embed0):
# 1. timestep embedding
modulated_inp = embed0 if self.use_ret_steps else embed
# 2. L1 calculate
should_calc = False
-if self.infer_conditional:
-if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+if self.scheduler.infer_condition:
+if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
should_calc = True
self.accumulated_rel_l1_distance_even = 0
else:
@@ -66,7 +66,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self.previous_e0_even = self.previous_e0_even.cpu()
else:
-if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+if self.scheduler.step_index < self.ret_steps or self.scheduler.step_index >= self.cutoff_steps:
should_calc = True
self.accumulated_rel_l1_distance_odd = 0
else:
@@ -95,35 +95,30 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
# 3. return the judgement
return should_calc
-def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-if self.infer_conditional:
+def infer_main_blocks(self, weights, pre_infer_out):
+if self.scheduler.infer_condition:
index = self.scheduler.step_index
caching_records = self.scheduler.caching_records
if index <= self.scheduler.infer_steps - 1:
-should_calc = self.calculate_should_calc(embed, embed0)
+should_calc = self.calculate_should_calc(pre_infer_out.embed, pre_infer_out.embed0)
self.scheduler.caching_records[index] = should_calc
if caching_records[index] or self.must_calc(index):
-x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+x = self.infer_calculating(weights, pre_infer_out)
else:
-x = self.infer_using_cache(x)
+x = self.infer_using_cache(pre_infer_out.x)
else:
index = self.scheduler.step_index
caching_records_2 = self.scheduler.caching_records_2
if index <= self.scheduler.infer_steps - 1:
-should_calc = self.calculate_should_calc(embed, embed0)
+should_calc = self.calculate_should_calc(pre_infer_out.embed, pre_infer_out.embed0)
self.scheduler.caching_records_2[index] = should_calc
if caching_records_2[index] or self.must_calc(index):
-x = self.infer_calculating(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
+x = self.infer_calculating(weights, pre_infer_out)
else:
-x = self.infer_using_cache(x)
+x = self.infer_using_cache(pre_infer_out.x)
-if self.config.enable_cfg:
-self.switch_status()
-self.cnt += 1
if self.clean_cuda_cache:
del grid_sizes, embed, embed0, seq_lens, freqs, context
@@ -131,20 +126,11 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
return x
-def infer_calculating(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-ori_x = x.clone()
-x = super().infer(
-weights,
-grid_sizes,
-embed,
-x,
-embed0,
-seq_lens,
-freqs,
-context,
-)
-if self.infer_conditional:
+def infer_calculating(self, weights, pre_infer_out):
+ori_x = pre_infer_out.x.clone()
+x = super().infer_main_blocks(weights, pre_infer_out)
+if self.scheduler.infer_condition:
self.previous_residual_even = x - ori_x
if self.config["cpu_offload"]:
self.previous_residual_even = self.previous_residual_even.cpu()
@@ -161,7 +147,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
return x
def infer_using_cache(self, x):
-if self.infer_conditional:
+if self.scheduler.infer_condition:
x.add_(self.previous_residual_even.cuda())
else:
x.add_(self.previous_residual_odd.cuda())
...
@@ -33,7 +33,7 @@ class WanPreInfer:
self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
-def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
+def infer(self, weights, inputs, kv_start=0, kv_end=0):
x = self.scheduler.latents
if self.scheduler.flag_df:
@@ -45,7 +45,7 @@ class WanPreInfer:
if self.config["model_cls"] == "wan2.2" and self.config["task"] == "i2v":
t = (self.scheduler.mask[0][:, ::2, ::2] * t).flatten()
-if positive:
+if self.scheduler.infer_condition:
context = inputs["text_encoder_output"]["context"]
else:
context = inputs["text_encoder_output"]["context_null"]
...
@@ -78,11 +78,6 @@ class WanTransformerInfer(BaseTransformerInfer):
else:
self.infer_func = self._infer_without_offload
-self.infer_conditional = True
-def switch_status(self):
-self.infer_conditional = not self.infer_conditional
def _calculate_q_k_len(self, q, k_lens):
q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
@@ -104,6 +99,10 @@ class WanTransformerInfer(BaseTransformerInfer):
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, pre_infer_out):
+x = self.infer_main_blocks(weights, pre_infer_out)
+return self.infer_post_blocks(weights, x, pre_infer_out.embed)
+def infer_main_blocks(self, weights, pre_infer_out):
x = self.infer_func(
weights,
pre_infer_out.grid_sizes,
@@ -115,9 +114,9 @@ class WanTransformerInfer(BaseTransformerInfer):
pre_infer_out.context,
pre_infer_out.audio_dit_blocks,
)
-return self._infer_post_blocks(weights, x, pre_infer_out.embed)
+return x
-def _infer_post_blocks(self, weights, x, e):
+def infer_post_blocks(self, weights, x, e):
if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim
e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
...
@@ -53,29 +53,33 @@ class WanLoraWrapper:
def _apply_lora_weights(self, weight_dict, lora_weights, alpha):
lora_pairs = {}
lora_diffs = {}
-prefix = "diffusion_model."
-def try_lora_pair(key, suffix_a, suffix_b, target_suffix):
+def try_lora_pair(key, prefix, suffix_a, suffix_b, target_suffix):
if key.endswith(suffix_a):
base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
pair_key = key.replace(suffix_a, suffix_b)
if pair_key in lora_weights:
lora_pairs[base_name] = (key, pair_key)
-def try_lora_diff(key, suffix, target_suffix):
+def try_lora_diff(key, prefix, suffix, target_suffix):
if key.endswith(suffix):
base_name = key[len(prefix) :].replace(suffix, target_suffix)
lora_diffs[base_name] = key
-for key in lora_weights.keys():
-if not key.startswith(prefix):
-continue
-try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight")
-try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight")
-try_lora_diff(key, "diff", "weight")
-try_lora_diff(key, "diff_b", "bias")
-try_lora_diff(key, "diff_m", "modulation")
+prefixs = [
+"", # empty prefix
+"diffusion_model.",
+]
+for prefix in prefixs:
+for key in lora_weights.keys():
+if not key.startswith(prefix):
+continue
+try_lora_pair(key, prefix, "lora_A.weight", "lora_B.weight", "weight")
+try_lora_pair(key, prefix, "lora_down.weight", "lora_up.weight", "weight")
+try_lora_diff(key, prefix, "diff", "weight")
+try_lora_diff(key, prefix, "diff_b", "bias")
+try_lora_diff(key, prefix, "diff_m", "modulation")
applied_count = 0
for name, param in weight_dict.items():
...
@@ -329,9 +329,9 @@ class WanModel:
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
-noise_pred = self._infer_cond_uncond(inputs, positive=True)
+noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
else:
-noise_pred = self._infer_cond_uncond(inputs, positive=False)
+noise_pred = self._infer_cond_uncond(inputs, infer_condition=False)
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
@@ -339,13 +339,13 @@ class WanModel:
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
else:
# ==================== CFG Processing ====================
-noise_pred_cond = self._infer_cond_uncond(inputs, positive=True)
-noise_pred_uncond = self._infer_cond_uncond(inputs, positive=False)
+noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
+noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
# ==================== No CFG ====================
-self.scheduler.noise_pred = self._infer_cond_uncond(inputs, positive=True)
+self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
@@ -355,8 +355,10 @@ class WanModel:
self.transformer_weights.post_weights_to_cpu()
@torch.no_grad()
-def _infer_cond_uncond(self, inputs, positive=True):
-pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=positive)
+def _infer_cond_uncond(self, inputs, infer_condition=True):
+self.scheduler.infer_condition = infer_condition
+pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
if self.config["seq_parallel"]:
pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
...
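For reference, the CFG combination this diff leaves unchanged is the standard classifier-free guidance mix; a minimal sketch (the helper name is illustrative, not part of the repo):

import torch

def apply_cfg(noise_pred_cond, noise_pred_uncond, guide_scale):
    # noise_pred = uncond + scale * (cond - uncond); scale == 1.0 returns the conditional prediction.
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

cond, uncond = torch.randn(2, 4), torch.randn(2, 4)
assert torch.allclose(apply_cfg(cond, uncond, 1.0), cond)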
import gc
import torch
from loguru import logger
from lightx2v.models.input_encoders.hf.qwen25.qwen25_vlforconditionalgeneration import Qwen25_VLForConditionalGeneration_TextEncoder
from lightx2v.models.networks.qwen_image.model import QwenImageTransformerModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.qwen_image.scheduler import QwenImageScheduler
from lightx2v.models.video_encoders.hf.qwen_image.vae import AutoencoderKLQwenImageVAE
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("qwen_image")
class QwenImageRunner(DefaultRunner):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(self, config):
super().__init__(config)
@ProfilingContext("Load models")
def load_model(self):
self.model = self.load_transformer()
self.text_encoders = self.load_text_encoder()
self.vae = self.load_vae()
def load_transformer(self):
model = QwenImageTransformerModel(self.config)
return model
def load_text_encoder(self):
text_encoder = Qwen25_VLForConditionalGeneration_TextEncoder(self.config)
text_encoders = [text_encoder]
return text_encoders
def load_image_encoder(self):
pass
def load_vae(self):
vae = AutoencoderKLQwenImageVAE(self.config)
return vae
def init_modules(self):
logger.info("Initializing runner modules...")
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.load_model()
elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False)
self.run_dit = self._run_dit_local
self.run_vae_decoder = self._run_vae_decoder_local
if self.config["task"] == "t2i":
self.run_input_encoder = self._run_input_encoder_local_i2v
else:
raise NotImplementedError(f"Unsupported task: {self.config['task']}")
@ProfilingContext("Run Encoders")
def _run_input_encoder_local_i2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
text_encoder_output = self.run_text_encoder(prompt)
torch.cuda.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": None,
}
def run_text_encoder(self, text):
text_encoder_output = {}
prompt_embeds, prompt_embeds_mask = self.text_encoders[0].infer([text])
text_encoder_output["prompt_embeds"] = prompt_embeds
text_encoder_output["prompt_embeds_mask"] = prompt_embeds_mask
return text_encoder_output
def set_target_shape(self):
self.vae_scale_factor = self.vae.vae_scale_factor if getattr(self, "vae", None) else 8
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
num_channels_latents = self.model.in_channels // 4
self.config.target_shape = (self.config.batchsize, 1, num_channels_latents, height, width)
def init_scheduler(self):
scheduler = QwenImageScheduler(self.config)
self.model.set_scheduler(scheduler)
self.model.pre_infer.set_scheduler(scheduler)
self.model.transformer_infer.set_scheduler(scheduler)
self.model.post_infer.set_scheduler(scheduler)
def get_encoder_output_i2v(self):
pass
def run_image_encoder(self):
pass
def run_vae_encoder(self):
pass
@ProfilingContext("Load models")
def load_model(self):
self.model = self.load_transformer()
self.text_encoders = self.load_text_encoder()
self.image_encoder = self.load_image_encoder()
self.vae = self.load_vae()
self.vfi_model = self.load_vfi_model() if "video_frame_interpolation" in self.config else None
@ProfilingContext("Run VAE Decoder")
def _run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae()
images = self.vae.decode(latents)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
return images
def run_pipeline(self, save_image=True):
if self.config["use_prompt_enhancer"]:
self.config["prompt_enhanced"] = self.post_prompt_enhancer()
self.inputs = self.run_input_encoder()
self.set_target_shape()
latents, generator = self.run_dit()
images = self.run_vae_decoder(latents, generator)
image = images[0]
image.save(f"{self.config.save_video_path}")
del latents, generator
torch.cuda.empty_cache()
gc.collect()
# Return (images, audio) - audio is None for default runner
return images, None
import inspect
import json
import os
from typing import List, Optional, Union
import numpy as np
import torch
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
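# Worked example with the defaults above: a 1664x928 image with an 8x VAE and 2x2 packing gives
# image_seq_len = (1664 / 16) * (928 / 16) = 104 * 58 = 6032 tokens, so
# m = (1.15 - 0.5) / (4096 - 256) ≈ 1.693e-4, b = 0.5 - m * 256 ≈ 0.4567, and
# mu ≈ 6032 * m + b ≈ 1.48, which set_timesteps() below passes to the scheduler via retrieve_timesteps(mu=...).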
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.")
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class QwenImageScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.config = config
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(os.path.join(config.model_path, "scheduler"))
with open(os.path.join(config.model_path, "scheduler", "scheduler_config.json"), "r") as f:
self.scheduler_config = json.load(f)
self.generator = torch.Generator(device="cuda").manual_seed(config.seed)
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
self.guidance_scale = 1.0
@staticmethod
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (vae_scale_factor * 2))
width = 2 * (int(width) // (vae_scale_factor * 2))
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
return latents
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height, width, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)
return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents(self):
shape = self.config.target_shape
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
self.vae_scale_factor = self.config.vae_scale_factor if getattr(self, "vae", None) else 8
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
latents = randn_tensor(shape, generator=self.generator, device=self.device, dtype=self.dtype)
latents = self._pack_latents(latents, self.config.batchsize, self.config.num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(self.config.batchsize, height // 2, width // 2, self.device, self.dtype)
self.latents = latents
self.latent_image_ids = latent_image_ids
self.noise_pred = None
def set_timesteps(self):
sigmas = np.linspace(1.0, 1 / self.config.infer_steps, self.config.infer_steps)
image_seq_len = self.latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler_config.get("base_image_seq_len", 256),
self.scheduler_config.get("max_image_seq_len", 4096),
self.scheduler_config.get("base_shift", 0.5),
self.scheduler_config.get("max_shift", 1.15),
)
num_inference_steps = self.config.infer_steps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
self.device,
sigmas=sigmas,
mu=mu,
)
self.timesteps = timesteps
self.infer_steps = num_inference_steps
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
self.num_warmup_steps = num_warmup_steps
def prepare_guidance(self):
# handle guidance
if self.config.guidance_embeds:
guidance = torch.full([1], self.guidance_scale, device=self.device, dtype=torch.float32)
guidance = guidance.expand(self.latents.shape[0])
else:
guidance = None
self.guidance = guidance
def set_img_shapes(self):
width, height = self.config.aspect_ratios[self.config.aspect_ratio]
self.img_shapes = [(1, height // self.config.vae_scale_factor // 2, width // self.config.vae_scale_factor // 2)] * self.config.batchsize
def prepare(self, image_encoder_output):
self.prepare_latents()
self.prepare_guidance()
self.set_img_shapes()
self.set_timesteps()
def step_post(self):
# compute the previous noisy sample x_t -> x_t-1
t = self.timesteps[self.step_index]
latents = self.scheduler.step(self.noise_pred, t, self.latents, return_dict=False)[0]
self.latents = latents
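A round-trip sketch for the packing helpers above (with QwenImageScheduler in scope; both are staticmethods, so no scheduler instance or model path is needed): _pack_latents folds each 2x2 latent patch into the channel dimension, and _unpack_latents restores the (B, C, 1, H, W) layout.

import torch

b, c, h, w = 1, 16, 116, 208                       # the "16:9" latent grid from the config
latents = torch.randn(b, c, h, w)
packed = QwenImageScheduler._pack_latents(latents, b, c, h, w)
print(packed.shape)                                # torch.Size([1, 6032, 64])
restored = QwenImageScheduler._unpack_latents(packed, h * 8, w * 8, vae_scale_factor=8)
torch.testing.assert_close(restored, latents.unsqueeze(2))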
@@ -10,6 +10,7 @@ class BaseScheduler:
self.caching_records = [True] * config.infer_steps
self.flag_df = False
self.transformer_infer = None
+self.infer_condition = True # cfg status
def step_pre(self, step_index):
self.step_index = step_index
...