Unverified Commit f21da849 authored by Yang Yong (雍洋), committed by GitHub
parent 3efc43f5
......@@ -5,14 +5,15 @@ from lightx2v.utils.envs import *
class SekoAudioEncoderModel:
def __init__(self, model_path, audio_sr, cpu_offload, device):
def __init__(self, model_path, audio_sr, cpu_offload, run_device):
self.model_path = model_path
self.audio_sr = audio_sr
self.cpu_offload = cpu_offload
if self.cpu_offload:
self.device = torch.device("cpu")
else:
self.device = torch.device(device)
self.device = torch.device(run_device)
self.run_device = run_device
self.load()
def load(self):
......@@ -26,13 +27,13 @@ class SekoAudioEncoderModel:
self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
def to_cuda(self):
self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
@torch.no_grad()
def infer(self, audio_segment):
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.device).to(dtype=GET_DTYPE())
audio_feat = self.audio_feature_extractor(audio_segment, sampling_rate=self.audio_sr, return_tensors="pt").input_values.to(self.run_device).to(dtype=GET_DTYPE())
if self.cpu_offload:
self.audio_feature_encoder = self.audio_feature_encoder.to("cuda")
self.audio_feature_encoder = self.audio_feature_encoder.to(self.run_device)
audio_feat = self.audio_feature_encoder(audio_feat, return_dict=True).last_hidden_state
if self.cpu_offload:
self.audio_feature_encoder = self.audio_feature_encoder.to("cpu")
......
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
# 1. Standard library imports
import gc
import math
import os
import sys
from pathlib import Path
# 2. Third-party library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -747,6 +745,7 @@ class T5EncoderModel:
text_len,
dtype=torch.bfloat16,
device=torch.device("cuda"),
run_device=torch.device("cuda"),
checkpoint_path=None,
tokenizer_path=None,
shard_fn=None,
......@@ -759,6 +758,7 @@ class T5EncoderModel:
self.text_len = text_len
self.dtype = dtype
self.device = device
self.run_device = run_device
if t5_quantized_ckpt is not None and t5_quantized:
self.checkpoint_path = t5_quantized_ckpt
else:
......@@ -807,8 +807,8 @@ class T5EncoderModel:
def infer(self, texts):
ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True)
ids = ids.to(self.device)
mask = mask.to(self.device)
ids = ids.to(self.run_device)
mask = mask.to(self.run_device)
seq_lens = mask.gt(0).sum(dim=1).long()
with torch.no_grad():
......
......@@ -292,7 +292,7 @@ class VisionTransformer(nn.Module):
b = x.size(0)
# embeddings
x = self.patch_embedding(x.type(self.patch_embedding.weight.type())).flatten(2).permute(0, 2, 1)
x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
if self.pool_type in ("token", "token_fc"):
x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
if interpolation:
......@@ -426,9 +426,10 @@ def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-r
class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False):
def __init__(self, dtype, device, checkpoint_path, clip_quantized, clip_quantized_ckpt, quant_scheme, cpu_offload=False, use_31_block=True, load_from_rank0=False, run_device=torch.device("cuda")):
self.dtype = dtype
self.device = device
self.run_device = run_device
self.quantized = clip_quantized
self.cpu_offload = cpu_offload
self.use_31_block = use_31_block
......@@ -462,7 +463,7 @@ class CLIPModel:
return out
def to_cuda(self):
self.model = self.model.cuda()
self.model = self.model.to(self.run_device)
def to_cpu(self):
self.model = self.model.cpu()
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
from loguru import logger
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
flash_attn_varlen_func_v3 = None
logger.info("flash_attn_varlen_func_v3 not available")
if torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
try:
from sageattn3 import sageattn3_blackwell
except ImportError:
logger.info("sageattn3 not found, please install sageattention first")
sageattn3_blackwell = None
def flash_attn_no_pad(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_seqlens,
max_s,
dropout_p,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic,
)
output = rearrange(
pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen),
"b s (h d) -> b s h d",
h=nheads,
)
return output
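def _example_flash_attn_no_pad():
    # Hedged usage sketch (not part of the original change): illustrates the packed
    # QKV layout [B, S, 3, H, D] and the boolean key-padding mask [B, S] expected by
    # flash_attn_no_pad. Shapes are illustrative; requires flash-attn and a CUDA device.
    qkv = torch.randn(2, 128, 3, 8, 64, device="cuda", dtype=torch.bfloat16)
    mask = torch.ones(2, 128, dtype=torch.bool, device="cuda")
    mask[:, -16:] = False  # mark the last 16 tokens of each sample as padding
    out = flash_attn_no_pad(qkv, mask)  # -> [2, 128, 8, 64]
    return out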
def flash_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
if flash_attn_varlen_func_v3 is None:
raise ImportError("FlashAttention V3 backend not available")
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads)
output_unpad = flash_attn_varlen_func_v3(
query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_q, softmax_scale=softmax_scale, causal=causal, deterministic=deterministic
)
output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads)
return output
def sage_attn_no_pad_v2(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads)
output_unpad = sageattn(
query_unpad.unsqueeze(0),
key_unpad.unsqueeze(0),
value_unpad.unsqueeze(0),
tensor_layout="NHD",
).squeeze(0)
output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads)
return output
def sage_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
value_unpad = rearrange(value_unpad, "nnz (h d) -> nnz h d", h=nheads)
output_unpad = sageattn3_blackwell(query_unpad.unsqueeze(0).transpose(1, 2), key_unpad.unsqueeze(0).transpose(1, 2), value_unpad.unsqueeze(0).transpose(1, 2)).transpose(1, 2).squeeze(0)
output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen), "b s (h d) -> b s h d", h=nheads)
return output
import gc
import json
import numpy as np
import torch
import torch.nn.functional as F
from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.magcache_thresh = config.get("magcache_thresh", 0.2)
self.K = config.get("magcache_K", 6)
self.retention_ratio = config.get("magcache_retention_ratio", 0.2)
self.mag_ratios = np.array(config.get("magcache_ratios", []))
self.enable_magcache_calibration = config.get("magcache_calibration", True)
# {True: cond_param, False: uncond_param}
self.accumulated_err = {True: 0.0, False: 0.0}
self.accumulated_steps = {True: 0, False: 0}
self.accumulated_ratio = {True: 1.0, False: 1.0}
self.residual_cache = {True: None, False: None}
self.residual_cache_txt = {True: None, False: None}
# calibration args
self.norm_ratio = [[1.0], [1.0]] # mean of magnitude ratio
self.norm_std = [[0.0], [0.0]] # std of magnitude ratio
self.cos_dis = [[0.0], [0.0]] # cosine distance of residual features
@torch.no_grad()
def infer(self, weights, infer_module_out):
skip_forward = False
step_index = self.scheduler.step_index
infer_condition = self.scheduler.infer_condition
if self.enable_magcache_calibration:
skip_forward = False
else:
if step_index >= int(self.config["infer_steps"] * self.retention_ratio):
# conditional and unconditional in one list
cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index]
# magnitude ratio between current step and the cached step
self.accumulated_ratio[infer_condition] = self.accumulated_ratio[infer_condition] * cur_mag_ratio
self.accumulated_steps[infer_condition] += 1 # skip steps plus 1
# skip error of current steps
cur_skip_err = np.abs(1 - self.accumulated_ratio[infer_condition])
# accumulated error of multiple steps
self.accumulated_err[infer_condition] += cur_skip_err
if self.accumulated_err[infer_condition] < self.magcache_thresh and self.accumulated_steps[infer_condition] <= self.K:
skip_forward = True
else:
self.accumulated_err[infer_condition] = 0
self.accumulated_steps[infer_condition] = 0
self.accumulated_ratio[infer_condition] = 1.0
if not skip_forward:
self.infer_calculating(weights, infer_module_out)
else:
self.infer_using_cache(infer_module_out)
x = self.infer_final_layer(weights, infer_module_out)
return x
def infer_calculating(self, weights, infer_module_out):
step_index = self.scheduler.step_index
infer_condition = self.scheduler.infer_condition
ori_img = infer_module_out.img.clone()
ori_txt = infer_module_out.txt.clone()
self.infer_func(weights, infer_module_out)
previous_residual = infer_module_out.img - ori_img
previous_residual_txt = infer_module_out.txt - ori_txt
if self.config["cpu_offload"]:
previous_residual = previous_residual.cpu()
previous_residual_txt = previous_residual_txt.cpu()
if self.enable_magcache_calibration and step_index >= 1:
norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item()
norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item()
cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item()
_index = int(not infer_condition)
self.norm_ratio[_index].append(round(norm_ratio, 5))
self.norm_std[_index].append(round(norm_std, 5))
self.cos_dis[_index].append(round(cos_dis, 5))
print(f"time: {step_index}, infer_condition: {infer_condition}, norm_ratio: {norm_ratio}, norm_std: {norm_std}, cos_dis: {cos_dis}")
self.residual_cache[infer_condition] = previous_residual
self.residual_cache_txt[infer_condition] = previous_residual_txt
if self.config["cpu_offload"]:
ori_img = ori_img.to("cpu")
ori_txt = ori_txt.to("cpu")
del ori_img, ori_txt
torch.cuda.empty_cache()
gc.collect()
def infer_using_cache(self, infer_module_out):
residual_img = self.residual_cache[self.scheduler.infer_condition]
residual_txt = self.residual_cache_txt[self.scheduler.infer_condition]
infer_module_out.img.add_(residual_img.cuda())
infer_module_out.txt.add_(residual_txt.cuda())
def clear(self):
self.accumulated_err = {True: 0.0, False: 0.0}
self.accumulated_steps = {True: 0, False: 0}
self.accumulated_ratio = {True: 1.0, False: 1.0}
self.residual_cache = {True: None, False: None}
self.residual_cache_txt = {True: None, False: None}
if self.enable_magcache_calibration:
print("norm ratio")
print(self.norm_ratio)
print("norm std")
print(self.norm_std)
print("cos_dis")
print(self.cos_dis)
def save_json(filename, obj_list):
with open(filename + ".json", "w") as f:
json.dump(obj_list, f)
save_json("mag_ratio", self.norm_ratio)
save_json("mag_std", self.norm_std)
save_json("cos_dis", self.cos_dis)
class HunyuanTransformerInferTeaCaching(HunyuanVideo15OffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.teacache_thresh = self.config["teacache_thresh"]
self.coefficients = self.config["coefficients"]
self.accumulated_rel_l1_distance_odd = 0
self.previous_modulated_input_odd = None
self.previous_residual_odd = None
self.accumulated_rel_l1_distance_even = 0
self.previous_modulated_input_even = None
self.previous_residual_even = None
def calculate_should_calc(self, img, vec, block):
inp = img.clone()
vec_ = vec.clone()
img_mod_layer = block.img_branch.img_mod
if self.config["cpu_offload"]:
img_mod_layer.to_cuda()
img_mod1_shift, img_mod1_scale, _, _, _, _ = img_mod_layer.apply(vec_).chunk(6, dim=-1)
inp = inp.squeeze(0)
normed_inp = torch.nn.functional.layer_norm(inp, (inp.shape[1],), None, None, 1e-6)
modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
del normed_inp, inp, vec_
if self.scheduler.step_index == 0 or self.scheduler.step_index == self.scheduler.infer_steps - 1:
should_calc = True
if self.scheduler.infer_condition:
self.accumulated_rel_l1_distance_odd = 0
self.previous_modulated_input_odd = modulated_inp
else:
self.accumulated_rel_l1_distance_even = 0
self.previous_modulated_input_even = modulated_inp
else:
rescale_func = np.poly1d(self.coefficients)
if self.scheduler.infer_condition:
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_modulated_input_odd).abs().mean() / self.previous_modulated_input_odd.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_modulated_input_odd = modulated_inp
else:
self.accumulated_rel_l1_distance_even += rescale_func(
((modulated_inp - self.previous_modulated_input_even).abs().mean() / self.previous_modulated_input_even.abs().mean()).cpu().item()
)
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance_even = 0
self.previous_modulated_input_even = modulated_inp
del modulated_inp
return should_calc
def infer(self, weights, infer_module_out):
should_calc = self.calculate_should_calc(infer_module_out.img, infer_module_out.vec, weights.double_blocks[0])
if not should_calc:
if self.scheduler.infer_condition:
infer_module_out.img += self.previous_residual_odd
else:
infer_module_out.img += self.previous_residual_even
else:
ori_img = infer_module_out.img.clone()
self.infer_func(weights, infer_module_out)
if self.scheduler.infer_condition:
self.previous_residual_odd = infer_module_out.img - ori_img
else:
self.previous_residual_even = infer_module_out.img - ori_img
x = self.infer_final_layer(weights, infer_module_out)
return x
def clear(self):
if self.previous_residual_odd is not None:
self.previous_residual_odd = self.previous_residual_odd.cpu()
if self.previous_modulated_input_odd is not None:
self.previous_modulated_input_odd = self.previous_modulated_input_odd.cpu()
if self.previous_residual_even is not None:
self.previous_residual_even = self.previous_residual_even.cpu()
if self.previous_modulated_input_even is not None:
self.previous_modulated_input_even = self.previous_modulated_input_even.cpu()
self.previous_modulated_input_odd = None
self.previous_residual_odd = None
self.previous_modulated_input_even = None
self.previous_residual_even = None
torch.cuda.empty_cache()
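def _example_teacache_rel_l1():
    # Hedged sketch (not part of the original change): the relative L1 distance that
    # TeaCaching accumulates between consecutive modulated inputs, rescaled with the
    # configured polynomial. Coefficients and tensor values below are illustrative.
    rescale_func = np.poly1d([1.0, 0.0])  # placeholder: identity polynomial
    prev_modulated = torch.randn(1024, 3072)
    cur_modulated = prev_modulated + 0.01 * torch.randn_like(prev_modulated)
    rel_l1 = ((cur_modulated - prev_modulated).abs().mean() / prev_modulated.abs().mean()).item()
    return rescale_func(rel_l1)  # compared against teacache_thresh in infer()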
from dataclasses import dataclass
import torch
@dataclass
class HunyuanVideo15InferModuleOutput:
img: torch.Tensor
txt: torch.Tensor
vec: torch.Tensor
grid_sizes: tuple
@dataclass
class HunyuanVideo15ImgBranchOutput:
img_mod1_gate: torch.Tensor
img_mod2_shift: torch.Tensor
img_mod2_scale: torch.Tensor
img_mod2_gate: torch.Tensor
@dataclass
class HunyuanVideo15TxtBranchOutput:
txt_mod1_gate: torch.Tensor
txt_mod2_shift: torch.Tensor
txt_mod2_scale: torch.Tensor
txt_mod2_gate: torch.Tensor
import torch
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.hunyuan_video.infer.transformer_infer import HunyuanVideo15TransformerInfer
class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
def __init__(self, config):
super().__init__(config)
if self.config.get("cpu_offload", False):
offload_granularity = self.config.get("offload_granularity", "block")
if offload_granularity == "block":
self.infer_func = self.infer_with_blocks_offload
elif offload_granularity == "model":
self.infer_func = self.infer_without_offload
else:
raise NotImplementedError
if offload_granularity != "model":
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
@torch.no_grad()
def infer_with_blocks_offload(self, weights, infer_module_out):
for block_idx in range(self.double_blocks_num):
self.block_idx = block_idx
if block_idx == 0:
self.offload_manager.init_first_buffer(weights.double_blocks)
if block_idx < self.double_blocks_num - 1:
self.offload_manager.prefetch_weights(block_idx + 1, weights.double_blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
infer_module_out.img, infer_module_out.txt = self.infer_double_block(self.offload_manager.cuda_buffers[0], infer_module_out)
self.offload_manager.swap_blocks()
import torch
from lightx2v.utils.envs import *
class HunyuanVideo15PostInfer:
def __init__(self, config):
self.config = config
self.unpatchify_channels = config["out_channels"]
self.patch_size = config["patch_size"] # (1, 1, 1)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.no_grad()
def infer(self, x, pre_infer_out):
x = self.unpatchify(x, pre_infer_out.grid_sizes[0], pre_infer_out.grid_sizes[1], pre_infer_out.grid_sizes[2])
return x
def unpatchify(self, x, t, h, w):
"""
Unpatchify a tensorized input back to frame format.
Args:
x (Tensor): Input tensor of shape (N, t * h * w, pt * ph * pw * C)
t (int): Number of time steps
h (int): Height in patch units
w (int): Width in patch units
Returns:
Tensor: Output tensor of shape (N, C, t * pt, h * ph, w * pw)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
x = x[:, : t * h * w] # remove padding from seq parallel
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
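def _example_unpatchify_shapes():
    # Hedged shape check (not part of the original change): with patch_size (1, 1, 1)
    # and 16 output channels, a (N, t*h*w, C) token tensor unpatchifies back to
    # (N, C, t, h, w). The config dict below is illustrative only.
    post = HunyuanVideo15PostInfer({"out_channels": 16, "patch_size": (1, 1, 1)})
    t, h, w = 4, 8, 8
    x = torch.randn(1, t * h * w, 16)
    return post.unpatchify(x, t, h, w).shape  # torch.Size([1, 16, 4, 8, 8])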
import math
from typing import Optional
import torch
from einops import rearrange
from lightx2v.utils.envs import *
from .attn_no_pad import flash_attn_no_pad, flash_attn_no_pad_v3, sage_attn_no_pad_v2
from .module_io import HunyuanVideo15InferModuleOutput
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after applying the gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
@torch.compiler.disable
def attention(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, drop_rate: float = 0.0, attn_mask: Optional[torch.Tensor] = None, causal: bool = False, attn_type: str = "flash_attn2"
) -> torch.Tensor:
"""
Compute attention using flash_attn_no_pad.
Args:
q: Query tensor of shape [B, L, H, D]
k: Key tensor of shape [B, L, H, D]
v: Value tensor of shape [B, L, H, D]
drop_rate: Dropout rate for attention weights.
attn_mask: Optional attention mask of shape [B, L].
causal: Whether to apply causal masking.
attn_type: Attention backend to use: "flash_attn2", "flash_attn3", or "sage_attn2".
Returns:
Output tensor after attention of shape [B, L, H*D]
"""
qkv = torch.stack([q, k, v], dim=2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.bool()
if attn_type == "flash_attn2":
x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
elif attn_type == "flash_attn3":
x = flash_attn_no_pad_v3(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
elif attn_type == "sage_attn2":
x = sage_attn_no_pad_v2(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
class HunyuanVideo15PreInfer:
def __init__(self, config):
self.config = config
self.patch_size = config["patch_size"]
self.heads_num = config["heads_num"]
self.frequency_embedding_size = 256
self.max_period = 10000
def set_scheduler(self, scheduler):
self.scheduler = scheduler
@torch.no_grad()
def infer(self, weights, inputs):
latents = self.scheduler.latents
grid_sizes_t, grid_sizes_h, grid_sizes_w = latents.shape[2:]
timesteps = self.scheduler.timesteps
t = timesteps[self.scheduler.step_index]
if self.scheduler.infer_condition:
txt, text_mask = inputs["text_encoder_output"]["context"][0], inputs["text_encoder_output"]["context"][1]
else:
txt, text_mask = inputs["text_encoder_output"]["context_null"][0], inputs["text_encoder_output"]["context_null"][1]
byt5_txt, byt5_text_mask = inputs["text_encoder_output"]["byt5_features"], inputs["text_encoder_output"]["byt5_masks"]
siglip_output, siglip_mask = inputs["image_encoder_output"]["siglip_output"], inputs["image_encoder_output"]["siglip_mask"]
txt = txt.to(torch.bfloat16)
if self.config["is_sr_running"]:
if t < 1000 * self.scheduler.noise_scale:
condition = self.scheduler.zero_condition
else:
condition = self.scheduler.condition
img = x = latent_model_input = torch.concat([latents, condition], dim=1)
else:
cond_latents_concat = self.scheduler.cond_latents_concat
mask_concat = self.scheduler.mask_concat
img = x = latent_model_input = torch.concat([latents, cond_latents_concat, mask_concat], dim=1)
img = img.to(torch.bfloat16)
t_expand = t.repeat(latent_model_input.shape[0])
guidance_expand = None
img = weights.img_in.apply(img)
img = img.flatten(2).transpose(1, 2)
t_freq = self.timestep_embedding(t_expand, self.frequency_embedding_size, self.max_period).to(torch.bfloat16)
vec = weights.time_in_0.apply(t_freq)
vec = torch.nn.functional.silu(vec)
vec = weights.time_in_2.apply(vec)
if self.config["is_sr_running"]:
use_meanflow = self.config.get("video_super_resolution", {}).get("use_meanflow", False)
if use_meanflow:
if self.scheduler.step_index == len(timesteps) - 1:
timesteps_r = torch.tensor([0.0], device=latent_model_input.device)
else:
timesteps_r = timesteps[self.scheduler.step_index + 1]
timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
else:
timesteps_r = None
if timesteps_r is not None:
t_freq = self.timestep_embedding(timesteps_r, self.frequency_embedding_size, self.max_period).to(torch.bfloat16)
vec_res = weights.time_r_in_0.apply(t_freq)
vec_res = torch.nn.functional.silu(vec_res)
vec_res = weights.time_r_in_2.apply(vec_res)
vec = vec + vec_res
t_freq = self.timestep_embedding(t_expand, self.frequency_embedding_size, self.max_period).to(torch.bfloat16)
timestep_aware_representations = weights.txt_in_t_embedder_0.apply(t_freq)
timestep_aware_representations = torch.nn.functional.silu(timestep_aware_representations)
timestep_aware_representations = weights.txt_in_t_embedder_2.apply(timestep_aware_representations)
mask_float = text_mask.float().unsqueeze(-1)
context_aware_representations = (txt * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = context_aware_representations.to(torch.bfloat16)
context_aware_representations = weights.txt_in_c_embedder_0.apply(context_aware_representations)
context_aware_representations = torch.nn.functional.silu(context_aware_representations)
context_aware_representations = weights.txt_in_c_embedder_2.apply(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
out = weights.txt_in_input_embedder.apply(txt[0].to(torch.bfloat16))
txt = self.run_individual_token_refiner(weights, out, text_mask, c)
# TODO: this computation can be removed
txt = txt.unsqueeze(0)
txt = txt + weights.cond_type_embedding.apply(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
byt5_txt = byt5_txt + weights.cond_type_embedding.apply(torch.ones_like(byt5_txt[:, :, 0], device=byt5_txt.device, dtype=torch.long))
txt, text_mask = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=True)
siglip_output = siglip_output + weights.cond_type_embedding.apply(2 * torch.ones_like(siglip_output[:, :, 0], dtype=torch.long, device=torch.device("cuda")))
txt, text_mask = self.reorder_txt_token(siglip_output, txt, siglip_mask, text_mask)
txt = txt[:, : text_mask.sum(), :]
return HunyuanVideo15InferModuleOutput(
img=img.contiguous(),
txt=txt.contiguous(),
vec=torch.nn.functional.silu(vec),
grid_sizes=(grid_sizes_t, grid_sizes_h, grid_sizes_w),
)
def run_individual_token_refiner(self, weights, out, mask, c):
mask = mask.clone().bool()
mask[:, 0] = True # Prevent attention weights from becoming NaN
for block in weights.individual_token_refiner: # block num = 2
gate_msa, gate_mlp = self.adaLN_modulation(block, c)
norm_x = block.norm1.apply(out.unsqueeze(0)).squeeze(0)
qkv = block.self_attn_qkv.apply(norm_x).unsqueeze(0)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
attn = attention(q, k, v, attn_mask=mask, attn_type="flash_attn2").squeeze(0)
out = out + apply_gate(block.self_attn_proj.apply(attn).unsqueeze(0), gate_msa).squeeze(0)
tmp = block.mlp_fc1.apply(block.norm2.apply(out))
tmp = torch.nn.functional.silu(tmp)
tmp = block.mlp_fc2.apply(tmp)
out = out + apply_gate(tmp.unsqueeze(0), gate_mlp).squeeze(0)
return out
def adaLN_modulation(self, weights, c):
c = torch.nn.functional.silu(c)
gate_msa, gate_mlp = weights.adaLN_modulation.apply(c).chunk(2, dim=1)
return gate_msa, gate_mlp
def timestep_embedding(self, t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int): the dimension of the output.
max_period (int): controls the minimum frequency of the embeddings.
Returns:
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def reorder_txt_token(self, byt5_txt, txt, byt5_text_mask, text_mask, zero_feat=False, is_reorder=True):
if is_reorder:
reorder_txt = []
reorder_mask = []
for i in range(text_mask.shape[0]):
byt5_text_mask_i = byt5_text_mask[i].bool()
text_mask_i = text_mask[i].bool()
byt5_txt_i = byt5_txt[i]
txt_i = txt[i]
if zero_feat:
# When using block mask with approximate computation, set pad to zero to reduce error
pad_byt5 = torch.zeros_like(byt5_txt_i[~byt5_text_mask_i])
pad_text = torch.zeros_like(txt_i[~text_mask_i])
reorder_txt_i = torch.cat([byt5_txt_i[byt5_text_mask_i], txt_i[text_mask_i], pad_byt5, pad_text], dim=0)
else:
reorder_txt_i = torch.cat([byt5_txt_i[byt5_text_mask_i], txt_i[text_mask_i], byt5_txt_i[~byt5_text_mask_i], txt_i[~text_mask_i]], dim=0)
reorder_mask_i = torch.cat([byt5_text_mask_i[byt5_text_mask_i], text_mask_i[text_mask_i], byt5_text_mask_i[~byt5_text_mask_i], text_mask_i[~text_mask_i]], dim=0)
reorder_txt.append(reorder_txt_i)
reorder_mask.append(reorder_mask_i)
reorder_txt = torch.stack(reorder_txt)
reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64)
else:
reorder_txt = torch.concat([byt5_txt, txt], dim=1)
reorder_mask = torch.concat([byt5_text_mask, text_mask], dim=1).to(dtype=torch.int64)
return reorder_txt, reorder_mask
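def _example_timestep_embedding():
    # Hedged sketch (not part of the original change): sinusoidal timestep embeddings
    # for a batch of two timesteps; dim=256 matches frequency_embedding_size above.
    # The config dict is illustrative only.
    pre = HunyuanVideo15PreInfer({"patch_size": (1, 1, 1), "heads_num": 24})
    t = torch.tensor([0.0, 999.0])
    emb = pre.timestep_embedding(t, dim=256)  # -> (2, 256), [cos | sin] halves
    return emb.shape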
from typing import Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from .module_io import HunyuanVideo15ImgBranchOutput, HunyuanVideo15TxtBranchOutput
from .triton_ops import fuse_scale_shift_kernel
def modulate(x, scale, shift):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after applying the gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
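def _example_modulate_and_gate():
    # Hedged sketch (not part of the original change): modulate() broadcasts per-batch
    # scale/shift of shape [B, C] over the sequence dim of x [B, L, C]; apply_gate()
    # does the same with a per-batch gate.
    x = torch.randn(2, 16, 64)
    scale, shift, gate = torch.randn(2, 64), torch.randn(2, 64), torch.randn(2, 64)
    y = modulate(x, scale, shift)      # x * (1 + scale[:, None]) + shift[:, None]
    return apply_gate(y, gate).shape   # torch.Size([2, 16, 64])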
def apply_hunyuan_rope_with_flashinfer(
xq: torch.Tensor,
xk: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
B, L, H, D = xq.shape
query = xq.reshape(B * L, H * D).contiguous()
key = xk.reshape(B * L, H * D).contiguous()
positions = torch.arange(B * L, device=xq.device, dtype=torch.long)
apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
head_size=D,
cos_sin_cache=cos_sin_cache,
is_neox=False,
)
xq_out = query.view(B, L, H, D)
xk_out = key.view(B, L, H, D)
return xq_out, xk_out
def apply_hunyuan_rope_with_torch(
xq: torch.Tensor,
xk: torch.Tensor,
cos_sin_cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
B, L, H, D = xq.shape
cos = cos_sin_cache[:, : D // 2]
sin = cos_sin_cache[:, D // 2 :]
def _apply_rope(x: torch.Tensor) -> torch.Tensor:
x_flat = x.view(B * L, H, D)
x1 = x_flat[..., ::2]
x2 = x_flat[..., 1::2]
cos_ = cos.unsqueeze(1)
sin_ = sin.unsqueeze(1)
o1 = x1.float() * cos_ - x2.float() * sin_
o2 = x2.float() * cos_ + x1.float() * sin_
out = torch.empty_like(x_flat)
out[..., ::2] = o1
out[..., 1::2] = o2
return out.view(B, L, H, D)
xq_out = _apply_rope(xq)
xk_out = _apply_rope(xk)
return xq_out, xk_out
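def _example_rope_with_torch():
    # Hedged sketch (not part of the original change): the torch RoPE fallback expects
    # a per-position cache of shape [B*L, D] laid out as [cos | sin] halves over the
    # D // 2 interleaved pairs. Values below are illustrative.
    B, L, H, D = 1, 8, 4, 64
    xq, xk = torch.randn(B, L, H, D), torch.randn(B, L, H, D)
    angles = torch.randn(B * L, D // 2)
    cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1)  # [B*L, D]
    q_out, k_out = apply_hunyuan_rope_with_torch(xq, xk, cos_sin_cache)
    return q_out.shape, k_out.shape  # both (B, L, H, D)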
class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
self.double_blocks_num = config["mm_double_blocks_depth"]
self.heads_num = config["heads_num"]
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
self.infer_func = self.infer_without_offload
if self.config.get("modulate_type", "triton") == "triton":
self.modulate_func = fuse_scale_shift_kernel
else:
self.modulate_func = modulate
if self.config.get("rope_type", "flashinfer") == "flashinfer":
self.apply_rope_func = apply_hunyuan_rope_with_flashinfer
else:
self.apply_rope_func = apply_hunyuan_rope_with_torch
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.scheduler.transformer_infer = self
@torch.no_grad()
def infer(self, weights, infer_module_out):
self.infer_func(weights, infer_module_out)
x = self.infer_final_layer(weights, infer_module_out)
return x
@torch.no_grad()
def infer_without_offload(self, weights, infer_module_out):
for i in range(self.double_blocks_num):
infer_module_out.img, infer_module_out.txt = self.infer_double_block(weights.double_blocks[i], infer_module_out)
@torch.no_grad()
def infer_final_layer(self, weights, infer_module_out):
x = torch.cat((infer_module_out.img, infer_module_out.txt), 1)
img = x[:, : infer_module_out.img.shape[1], ...]
shift, scale = weights.final_layer.adaLN_modulation.apply(infer_module_out.vec).chunk(2, dim=1)
img = self.modulate_func(weights.final_layer.norm_final.apply(img.squeeze(0)), scale=scale, shift=shift).squeeze(0)
img = weights.final_layer.linear.apply(img)
return img.unsqueeze(0)
@torch.no_grad()
def infer_double_block(self, weights, infer_module_out):
img_q, img_k, img_v, img_branch_out = self._infer_img_branch_before_attn(weights, infer_module_out)
txt_q, txt_k, txt_v, txt_branch_out = self._infer_txt_branch_before_attn(weights, infer_module_out)
img_attn, txt_attn = self._infer_attn(weights, img_q, img_k, img_v, txt_q, txt_k, txt_v)
img = self._infer_img_branch_after_attn(weights, img_attn, infer_module_out.img, img_branch_out)
txt = self._infer_txt_branch_after_attn(weights, txt_attn, infer_module_out.txt, txt_branch_out)
return img, txt
@torch.no_grad()
def _infer_img_branch_before_attn(self, weights, infer_module_out):
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = weights.img_branch.img_mod.apply(infer_module_out.vec).chunk(6, dim=-1)
img_modulated = weights.img_branch.img_norm1.apply(infer_module_out.img.squeeze(0))
img_modulated = self.modulate_func(img_modulated, scale=img_mod1_scale, shift=img_mod1_shift).squeeze(0)
img_q = weights.img_branch.img_attn_q.apply(img_modulated)
img_k = weights.img_branch.img_attn_k.apply(img_modulated)
img_v = weights.img_branch.img_attn_v.apply(img_modulated)
img_q = rearrange(img_q, "L (H D) -> L H D", H=self.heads_num)
img_k = rearrange(img_k, "L (H D) -> L H D", H=self.heads_num)
img_v = rearrange(img_v, "L (H D) -> L H D", H=self.heads_num)
img_q = weights.img_branch.img_attn_q_norm.apply(img_q)
img_k = weights.img_branch.img_attn_k_norm.apply(img_k)
img_q, img_k = self.apply_rope_func(img_q.unsqueeze(0), img_k.unsqueeze(0), cos_sin_cache=self.scheduler.cos_sin)
return (
img_q,
img_k,
img_v.unsqueeze(0),
HunyuanVideo15ImgBranchOutput(
img_mod1_gate=img_mod1_gate,
img_mod2_shift=img_mod2_shift,
img_mod2_scale=img_mod2_scale,
img_mod2_gate=img_mod2_gate,
),
)
@torch.no_grad()
def _infer_txt_branch_before_attn(self, weights, infer_module_out):
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = weights.txt_branch.txt_mod.apply(infer_module_out.vec).chunk(6, dim=-1)
txt_modulated = weights.txt_branch.txt_norm1.apply(infer_module_out.txt.squeeze(0))
txt_modulated = self.modulate_func(txt_modulated, scale=txt_mod1_scale, shift=txt_mod1_shift).squeeze(0)
txt_q = weights.txt_branch.txt_attn_q.apply(txt_modulated)
txt_k = weights.txt_branch.txt_attn_k.apply(txt_modulated)
txt_v = weights.txt_branch.txt_attn_v.apply(txt_modulated)
txt_q = rearrange(txt_q, "L (H D) -> L H D", H=self.heads_num)
txt_k = rearrange(txt_k, "L (H D) -> L H D", H=self.heads_num)
txt_v = rearrange(txt_v, "L (H D) -> L H D", H=self.heads_num)
txt_q = weights.txt_branch.txt_attn_q_norm.apply(txt_q).to(txt_v)
txt_k = weights.txt_branch.txt_attn_k_norm.apply(txt_k).to(txt_v)
return (
txt_q.unsqueeze(0),
txt_k.unsqueeze(0),
txt_v.unsqueeze(0),
HunyuanVideo15TxtBranchOutput(
txt_mod1_gate=txt_mod1_gate,
txt_mod2_shift=txt_mod2_shift,
txt_mod2_scale=txt_mod2_scale,
txt_mod2_gate=txt_mod2_gate,
),
)
@torch.no_grad()
def _infer_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v):
img_seqlen = img_q.shape[1]
query = torch.cat([img_q, txt_q], dim=1)
key = torch.cat([img_k, txt_k], dim=1)
value = torch.cat([img_v, txt_v], dim=1)
seqlen = query.shape[1]
cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu").to("cuda", non_blocking=True)
if self.config["seq_parallel"]:
attn_out = weights.self_attention_parallel.apply(
q=query,
k=key,
v=value,
img_qkv_len=img_seqlen,
cu_seqlens_qkv=cu_seqlens_qkv,
attention_module=weights.self_attention,
seq_p_group=self.seq_p_group,
)
else:
attn_out = weights.self_attention.apply(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=seqlen,
max_seqlen_kv=seqlen,
)
img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:]
return img_attn, txt_attn
@torch.no_grad()
def _infer_img_branch_after_attn(self, weights, img_attn, img, img_branch_out):
img = img + apply_gate(weights.img_branch.img_attn_proj.apply(img_attn).unsqueeze(0), gate=img_branch_out.img_mod1_gate)
out = weights.img_branch.img_mlp_fc1.apply(
self.modulate_func(weights.img_branch.img_norm2.apply(img.squeeze(0)), scale=img_branch_out.img_mod2_scale, shift=img_branch_out.img_mod2_shift).squeeze(0)
)
out = weights.img_branch.img_mlp_fc2.apply(F.gelu(out, approximate="tanh"))
img = img + apply_gate(out.unsqueeze(0), gate=img_branch_out.img_mod2_gate)
return img
@torch.no_grad()
def _infer_txt_branch_after_attn(self, weights, txt_attn, txt, txt_branch_out):
txt = txt + apply_gate(weights.txt_branch.txt_attn_proj.apply(txt_attn).unsqueeze(0), gate=txt_branch_out.txt_mod1_gate)
out = weights.txt_branch.txt_mlp_fc1.apply(
self.modulate_func(weights.txt_branch.txt_norm2.apply(txt.squeeze(0)), scale=txt_branch_out.txt_mod2_scale, shift=txt_branch_out.txt_mod2_shift).squeeze(0)
)
out = weights.txt_branch.txt_mlp_fc2.apply(F.gelu(out, approximate="tanh"))
txt = txt + apply_gate(out.unsqueeze(0), gate=txt_branch_out.txt_mod2_gate)
return txt
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang
# TODO: for temporary usage, expecting a refactor
from typing import Optional
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from torch import Tensor
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_warps=8),
],
key=["inner_dim"],
)
@triton.jit
def _fused_scale_shift_4d_kernel(
output_ptr,
normalized_ptr,
scale_ptr,
shift_ptr,
rows,
inner_dim,
seq_len,
num_frames,
frame_seqlen,
BLOCK_N: tl.constexpr,
):
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)
col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
mask = col_offsets < inner_dim
# Pointers for normalized and output
row_base = pid_row * inner_dim
norm_ptrs = normalized_ptr + row_base + col_offsets
out_ptrs = output_ptr + row_base + col_offsets
# Pointers for scale and shift for 4D
b_idx = pid_row // seq_len
t_idx = pid_row % seq_len
frame_idx_in_batch = t_idx // frame_seqlen
scale_row_idx = b_idx * num_frames + frame_idx_in_batch
scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets
shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets
normalized = tl.load(norm_ptrs, mask=mask, other=0.0)
scale = tl.load(scale_ptrs, mask=mask, other=0.0)
shift = tl.load(shift_ptrs, mask=mask, other=0.0)
one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype)
output = normalized * (one + scale) + shift
tl.store(out_ptrs, output, mask=mask)
@triton.jit
def fuse_scale_shift_kernel_blc_opt(
x_ptr,
shift_ptr,
scale_ptr,
y_ptr,
B,
L,
C,
stride_x_b,
stride_x_l,
stride_x_c,
stride_s_b,
stride_s_l,
stride_s_c,
stride_sc_b,
stride_sc_l,
stride_sc_c,
SCALE_IS_SCALAR: tl.constexpr,
SHIFT_IS_SCALAR: tl.constexpr,
BLOCK_L: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_l = tl.program_id(0)
pid_c = tl.program_id(1)
pid_b = tl.program_id(2)
l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
mask_l = l_offsets < L
mask_c = c_offsets < C
mask = mask_l[:, None] & mask_c[None, :]
x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c
x = tl.load(x_ptr + x_off, mask=mask, other=0)
if SHIFT_IS_SCALAR:
shift_val = tl.load(shift_ptr)
shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)
else:
s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c
shift = tl.load(shift_ptr + s_off, mask=mask, other=0)
if SCALE_IS_SCALAR:
scale_val = tl.load(scale_ptr)
scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)
else:
sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c
scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)
y = x * (1 + scale) + shift
tl.store(y_ptr + x_off, y, mask=mask)
def fuse_scale_shift_kernel(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
block_l: int = 128,
block_c: int = 128,
):
assert x.is_cuda and scale.is_cuda
assert x.is_contiguous()
if x.dim() == 2:
x = x.unsqueeze(0)
B, L, C = x.shape
output = torch.empty_like(x)
if scale.dim() == 4:
# scale/shift: [B, F, 1, C]
rows = B * L
x_2d = x.view(rows, C)
output_2d = output.view(rows, C)
grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa
num_frames = scale.shape[1]
assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift"
frame_seqlen = L // num_frames
# Compact [B, F, C] without the singleton dim into [B*F, C]
scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous()
_fused_scale_shift_4d_kernel[grid](
output_2d,
x_2d,
scale_reshaped,
shift_reshaped,
rows,
C,
L,
num_frames,
frame_seqlen,
)
else:
# 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
# 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
# Also support scalar (0D or 1-element)
if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
scale_blc = scale.reshape(1)
elif scale.dim() == 2:
scale_blc = scale[:, None, :]
elif scale.dim() == 3:
scale_blc = scale
else:
raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")
if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
shift_blc = shift.reshape(1)
elif shift.dim() == 2:
shift_blc = shift[:, None, :]
elif shift.dim() == 3:
shift_blc = shift
else:
# broadcast later via expand if possible
shift_blc = shift
need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1
if not need_scale_scalar:
scale_exp = scale_blc.expand(B, L, C)
s_sb, s_sl, s_sc = scale_exp.stride()
else:
s_sb = s_sl = s_sc = 0
if not need_shift_scalar:
shift_exp = shift_blc.expand(B, L, C)
sh_sb, sh_sl, sh_sc = shift_exp.stride()
else:
sh_sb = sh_sl = sh_sc = 0
# If both scalars and both zero, copy fast-path
if need_scale_scalar and need_shift_scalar:
if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0):
output.copy_(x)
return output
grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
fuse_scale_shift_kernel_blc_opt[grid](
x,
shift_blc if need_shift_scalar else shift_exp,
scale_blc if need_scale_scalar else scale_exp,
output,
B,
L,
C,
x.stride(0),
x.stride(1),
x.stride(2),
sh_sb,
sh_sl,
sh_sc,
s_sb,
s_sl,
s_sc,
SCALE_IS_SCALAR=need_scale_scalar,
SHIFT_IS_SCALAR=need_shift_scalar,
BLOCK_L=block_l,
BLOCK_C=block_c,
num_warps=4,
num_stages=2,
)
return output
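def _example_fuse_scale_shift():
    # Hedged sanity check (not part of the original change): the Triton kernel should
    # match the eager formula x * (1 + scale) + shift for a [B, C] scale/shift
    # broadcast over the sequence dim. Requires a CUDA device and Triton.
    x = torch.randn(2, 128, 512, device="cuda", dtype=torch.bfloat16)
    scale = torch.randn(2, 512, device="cuda", dtype=torch.bfloat16)
    shift = torch.randn(2, 512, device="cuda", dtype=torch.bfloat16)
    y_triton = fuse_scale_shift_kernel(x, scale, shift)
    y_eager = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    return torch.allclose(y_triton, y_eager, atol=1e-2, rtol=1e-2)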
@triton.autotune(
configs=[
triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
],
key=["head_size", "interleaved"],
)
@triton.jit
def _rotary_embedding_kernel(
output_ptr,
x_ptr,
cos_ptr,
sin_ptr,
num_heads,
head_size,
num_tokens,
stride_x_row,
stride_cos_row,
stride_sin_row,
interleaved: tl.constexpr,
BLOCK_HS_HALF: tl.constexpr,
):
row_idx = tl.program_id(0)
token_idx = (row_idx // num_heads) % num_tokens
x_row_ptr = x_ptr + row_idx * stride_x_row
cos_row_ptr = cos_ptr + token_idx * stride_cos_row
sin_row_ptr = sin_ptr + token_idx * stride_sin_row
output_row_ptr = output_ptr + row_idx * stride_x_row
# half size for x1 and x2
head_size_half = head_size // 2
for block_start in range(0, head_size_half, BLOCK_HS_HALF):
offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)
mask = offsets_half < head_size_half
cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)
sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)
offsets_x1 = 2 * offsets_half
offsets_x2 = 2 * offsets_half + 1
x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)
x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)
x1_fp32 = x1_vals.to(tl.float32)
x2_fp32 = x2_vals.to(tl.float32)
cos_fp32 = cos_vals.to(tl.float32)
sin_fp32 = sin_vals.to(tl.float32)
o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)
tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)
tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)
def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
output = torch.empty_like(x)
if x.dim() > 3:
bsz, num_tokens, num_heads, head_size = x.shape
else:
num_tokens, num_heads, head_size = x.shape
bsz = 1
assert head_size % 2 == 0, "head_size must be divisible by 2"
x_reshaped = x.view(-1, head_size)
output_reshaped = output.view(-1, head_size)
# one program per (token, head) row
grid = (bsz * num_tokens * num_heads,)
if interleaved and cos.shape[-1] == head_size:
cos = cos[..., ::2].contiguous()
sin = sin[..., ::2].contiguous()
else:
cos = cos.contiguous()
sin = sin.contiguous()
_rotary_embedding_kernel[grid](
output_reshaped,
x_reshaped,
cos,
sin,
num_heads,
head_size,
num_tokens,
x_reshaped.stride(0),
cos.stride(0),
sin.stride(0),
interleaved,
)
return output
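def _example_apply_rotary_embedding():
    # Hedged sketch (not part of the original change): per-token cos/sin tables of
    # shape [num_tokens, head_size // 2] applied to x of shape
    # [num_tokens, num_heads, head_size] in the interleaved-pair layout.
    # Requires a CUDA device and Triton.
    num_tokens, num_heads, head_size = 16, 4, 64
    x = torch.randn(num_tokens, num_heads, head_size, device="cuda")
    angles = torch.randn(num_tokens, head_size // 2, device="cuda")
    out = apply_rotary_embedding(x, angles.cos(), angles.sin())
    return out.shape  # (num_tokens, num_heads, head_size)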
# RMSNorm-fp32
def maybe_contiguous_lastdim(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def maybe_contiguous(x):
return x.contiguous() if x is not None else None
def triton_autotune_configs():
# Return configs with a valid warp count for the current device
configs = []
# Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
max_threads_per_block = 1024
# Default to warp size 32 if not defined by device
warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
# Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block]
# return [triton.Config({}, num_warps=8)]
# Copied from flash-attn
@triton.autotune(
configs=triton_autotune_configs(),
key=[
"N",
"HAS_RESIDUAL",
"STORE_RESIDUAL_OUT",
"IS_RMS_NORM",
"HAS_BIAS",
"HAS_WEIGHT",
"HAS_X1",
"HAS_W1",
"HAS_B1",
],
)
# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
X1,
W1,
B1,
Y1,
RESIDUAL_OUT, # pointer to the residual
ROWSCALE,
SEEDS, # Dropout seeds for each row
DROPOUT_MASK,
DROPOUT_MASK1,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
stride_x1_row,
stride_y1_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
dropout_p, # Dropout probability
zero_centered_weight, # If true, add 1.0 to the weight
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_X1: tl.constexpr,
HAS_W1: tl.constexpr,
HAS_B1: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
if HAS_X1:
X1 += row * stride_x1_row
if HAS_W1:
Y1 += row * stride_y1_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
x *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
if HAS_X1:
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
x1 *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
x += x1
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w += 1.0
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
y = x_hat * w + b if HAS_BIAS else x_hat * w
else:
y = x_hat + b if HAS_BIAS else x_hat
# Write output
tl.store(Y + cols, y, mask=mask)
if HAS_W1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w1 += 1.0
if HAS_B1:
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
x: Tensor,
weight: Tensor,
bias: Tensor,
eps: float,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
residual_dtype: Optional[torch.dtype] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
out: Optional[Tensor] = None,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
# Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
# and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
# so that _layer_norm_fwd_impl doesn't have to return them.
if out is None:
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
if residual is not None:
residual_dtype = residual.dtype
if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None):
residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype)
else:
residual_out = None
y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
x,
weight,
bias,
eps,
out,
residual=residual,
x1=x1,
weight1=weight1,
bias1=bias1,
dropout_p=dropout_p,
rowscale=rowscale,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
residual_out=residual_out,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
if residual_out is None:
residual_out = x
return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
# since we're returning a tuple of tensors
def _layer_norm_fwd_impl(
x: Tensor,
weight: Optional[Tensor],
bias: Tensor,
eps: float,
out: Tensor,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if x1 is not None:
assert x1.shape == x.shape
assert rowscale is None
assert x1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
assert out.shape == x.shape
assert out.stride(-1) == 1
if residual_out is not None:
assert residual_out.shape == x.shape
assert residual_out.stride(-1) == 1
if weight1 is not None:
y1 = torch.empty_like(out)
assert y1.stride(-1) == 1
else:
y1 = None
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
if dropout_p > 0.0:
seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64)
else:
seeds = None
if return_dropout_mask and dropout_p > 0.0:
dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
if x1 is not None:
dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
else:
dropout_mask1 = None
else:
dropout_mask, dropout_mask1 = None, None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
with torch.cuda.device(x.device.index):
torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
x,
out,
weight if weight is not None else x, # unused when HAS_WEIGHT == False
bias,
residual,
x1,
weight1,
bias1,
y1,
residual_out,
rowscale,
seeds,
dropout_mask,
dropout_mask1,
mean,
rstd,
x.stride(0),
out.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
x1.stride(0) if x1 is not None else 0,
y1.stride(0) if y1 is not None else 0,
M,
N,
eps,
dropout_p,
# Passing a bool makes torch inductor very unhappy since it then tries to compare it to int_max
int(zero_centered_weight),
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
dropout_p > 0.0,
dropout_mask is not None,
rowscale is not None,
HAS_X1=x1 is not None,
HAS_W1=weight1 is not None,
HAS_B1=bias1 is not None,
)
return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
class LayerNormFn:
@staticmethod
def forward(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
if residual is not None:
assert residual.shape == x_shape_og
residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
if x1 is not None:
assert x1.shape == x_shape_og
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
# weight can be None when elementwise_affine=False for LayerNorm
if weight is not None:
weight = weight.contiguous()
bias = maybe_contiguous(bias)
weight1 = maybe_contiguous(weight1)
bias1 = maybe_contiguous(bias1)
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
if out is not None:
out = out.reshape(-1, out.shape[-1])
if residual_out is not None:
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
x,
weight,
bias,
eps,
residual,
x1,
weight1,
bias1,
dropout_p=dropout_p,
rowscale=rowscale,
out_dtype=out_dtype,
residual_dtype=residual_dtype,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
out=out,
residual_out=residual_out,
)
y = y.reshape(x_shape_og)
return y
def layer_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
is_rms_norm,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
@triton.jit
def _norm_infer_kernel(
X,
Y,
W,
B,
stride_x_row,
stride_y_row,
M,
N,
eps,
IS_RMS_NORM: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
BLOCK_N: tl.constexpr,
):
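# One program instance normalizes one row; BLOCK_N spans the full feature dimension (the caller enforces N <= BLOCK_N), so the mean/variance reduction happens in a single pass.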
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_WEIGHT:
W += 0
if HAS_BIAS:
B += 0
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
y = x_hat * w
else:
y = x_hat
if HAS_BIAS:
b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
y += b
tl.store(Y + cols, y, mask=cols < N)
def norm_infer(
x: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
is_rms_norm: bool = False,
out: Optional[Tensor] = None,
):
M, N = x.shape
assert x.stride(-1) == 1
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.shape == (N,)
assert bias.stride(-1) == 1
if out is None:
out = torch.empty_like(x)
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_N // 256, 1), 8)
_norm_infer_kernel[(M,)](
x,
out,
weight if weight is not None else x, # dummy when HAS_WEIGHT=False
bias if bias is not None else x, # dummy when HAS_BIAS=False
x.stride(0),
out.stride(0),
M,
N,
eps,
IS_RMS_NORM=is_rms_norm,
HAS_WEIGHT=weight is not None,
HAS_BIAS=bias is not None,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
)
return out
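# Hedged usage sketch (illustrative only; `_norm_infer_smoke_test` is a hypothetical helper,
# not part of the upstream file): a minimal numerical check of the inference-only path above
# against an eager reference, assuming a CUDA device is available.
def _norm_infer_smoke_test():
    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
    w = torch.randn(1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, device="cuda", dtype=torch.float16)
    # Fused LayerNorm inference path vs. F.layer_norm computed in fp32.
    y = norm_infer(x, w, b, eps=1e-6, is_rms_norm=False)
    ref = F.layer_norm(x.float(), (1024,), w.float(), b.float(), eps=1e-6)
    torch.testing.assert_close(y.float(), ref, atol=1e-2, rtol=1e-2)
    # Fused RMSNorm inference path (no mean subtraction, bias omitted).
    y_rms = norm_infer(x, w, None, eps=1e-6, is_rms_norm=True)
    ref_rms = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6) * w.float()
    torch.testing.assert_close(y_rms.float(), ref_rms, atol=1e-2, rtol=1e-2)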
def rms_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
True,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
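# Hedged usage sketch (illustrative only; `_layer_norm_fn_example` is a hypothetical helper):
# how the fused wrappers above might be called on a 3D activation tensor, assuming a CUDA
# device. Residual fusion, dropout and rowscale are left disabled here.
def _layer_norm_fn_example():
    import torch

    hidden = torch.randn(2, 256, 1024, device="cuda", dtype=torch.bfloat16)
    weight = torch.ones(1024, device="cuda", dtype=torch.bfloat16)
    # Plain LayerNorm without bias; the wrapper reshapes to 2D internally
    # and restores the original shape on the way out.
    y_ln = layer_norm_fn(hidden, weight, bias=None, eps=1e-6)
    # RMSNorm variant: same call path with is_rms_norm forced to True.
    y_rms = rms_norm_fn(hidden, weight, bias=None, eps=1e-6)
    return y_ln, y_rms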
import gc
import glob
import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger
from safetensors import safe_open
from lightx2v.models.networks.hunyuan_video.infer.feature_caching.transformer_infer import HunyuanTransformerInferTeaCaching, HunyuanVideo15TransformerInferMagCaching
from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
from lightx2v.models.networks.hunyuan_video.infer.post_infer import HunyuanVideo15PostInfer
from lightx2v.models.networks.hunyuan_video.infer.pre_infer import HunyuanVideo15PreInfer
from lightx2v.models.networks.hunyuan_video.infer.transformer_infer import HunyuanVideo15TransformerInfer
from lightx2v.models.networks.hunyuan_video.weights.post_weights import HunyuanVideo15PostWeights
from lightx2v.models.networks.hunyuan_video.weights.pre_weights import HunyuanVideo15PreWeights
from lightx2v.models.networks.hunyuan_video.weights.transformer_weights import HunyuanVideo15TransformerWeights
from lightx2v.utils.custom_compiler import CompiledMethodsMixin
from lightx2v.utils.envs import *
class HunyuanVideo15Model(CompiledMethodsMixin):
def __init__(self, model_path, config, device):
super().__init__()
self.model_path = model_path
self.config = config
self.device = device
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.remove_keys = ["byt5_in", "vision_in"]
self.dit_quantized = self.config.get("dit_quantized", False)
if self.dit_quantized:
assert self.config.get("dit_quant_scheme", "Default") in [
"Default-Force-FP32",
"fp8-vllm",
"int8-vllm",
"fp8-q8f",
"int8-q8f",
"fp8-b128-deepgemm",
"fp8-sgl",
"int8-sgl",
"int8-torchao",
"nvfp4",
"mxfp4",
"mxfp6-mxfp8",
"mxfp8",
]
self._init_infer_class()
self._init_weights()
self._init_infer()
def _init_infer_class(self):
self.pre_infer_class = HunyuanVideo15PreInfer
self.post_infer_class = HunyuanVideo15PostInfer
if self.config["feature_caching"] == "NoCaching":
self.transformer_infer_class = HunyuanVideo15TransformerInfer if not self.cpu_offload else HunyuanVideo15OffloadTransformerInfer
elif self.config["feature_caching"] == "Mag":
self.transformer_infer_class = HunyuanVideo15TransformerInferMagCaching
elif self.config["feature_caching"] == "Tea":
self.transformer_infer_class = HunyuanTransformerInferTeaCaching
else:
raise NotImplementedError
def _init_weights(self):
unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
sensitive_layer = {}
if not self.dit_quantized:
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
self.original_weight_dict = weight_dict
self.pre_weight = HunyuanVideo15PreWeights(self.config)
self.transformer_weights = HunyuanVideo15TransformerWeights(self.config)
self.post_weight = HunyuanVideo15PostWeights(self.config)
self._apply_weights()
def _apply_weights(self, weight_dict=None):
if weight_dict is not None:
self.original_weight_dict = weight_dict
del weight_dict
gc.collect()
# Load weights into containers
self.pre_weight.load(self.original_weight_dict)
self.transformer_weights.load(self.original_weight_dict)
del self.original_weight_dict
torch.cuda.empty_cache()
gc.collect()
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.pre_infer.set_scheduler(scheduler)
self.transformer_infer.set_scheduler(scheduler)
self.post_infer.set_scheduler(scheduler)
def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.config.get("dit_quantized_ckpt", None):
safetensors_path = self.config["dit_quantized_ckpt"]
else:
safetensors_path = self.model_path
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
safetensors_path = os.path.dirname(safetensors_path)
weight_dict = {}
for safetensor_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == safetensor_path:
continue
with safe_open(safetensor_path, framework="pt") as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
if any(remove_key in k for remove_key in remove_keys):
continue
if f.get_tensor(k).dtype in [
torch.float16,
torch.bfloat16,
torch.float,
]:
if unified_dtype or all(s not in k for s in sensitive_layer):
weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
else:
weight_dict[k] = f.get_tensor(k).to(self.device)
if self.config.get("dit_quant_scheme", "Default") == "nvfp4":
calib_path = os.path.join(safetensors_path, "calib.pt")
logger.info(f"[CALIB] Loaded calibration data from: {calib_path}")
calib_data = torch.load(calib_path, map_location="cpu")
for k, v in calib_data["absmax"].items():
weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device)
return weight_dict
def _load_ckpt(self, unified_dtype, sensitive_layer):
if self.config.get("dit_original_ckpt", None):
safetensors_path = self.config["dit_original_ckpt"]
else:
safetensors_path = self.config["transformer_model_path"]
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
weight_dict = {}
for file_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == file_path:
continue
logger.info(f"Loading weights from {file_path}")
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
return weight_dict
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.device.type == "cuda" and dist.is_initialized():
device = torch.device("cuda:{}".format(dist.get_rank()))
else:
device = self.device
with safe_open(file_path, framework="pt", device=str(device)) as f:
return {
key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
for key in f.keys()
if not any(remove_key in key for remove_key in remove_keys)
}
def to_cpu(self):
self.pre_weight.to_cpu()
self.transformer_weights.to_cpu()
def to_cuda(self):
self.pre_weight.to_cuda()
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.transformer_weights.non_block_weights_to_cuda()
if self.config["enable_cfg"]:
if self.config["cfg_parallel"]:
# ==================== CFG Parallel Processing ====================
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2"
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
noise_pred = self._infer_cond_uncond(inputs, infer_condition=True).contiguous()
else:
noise_pred = self._infer_cond_uncond(inputs, infer_condition=False).contiguous()
noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)]
dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group)
noise_pred_cond = noise_pred_list[0] # cfg_p_rank == 0
noise_pred_uncond = noise_pred_list[1] # cfg_p_rank == 1
else:
# ==================== CFG Processing ====================
noise_pred_cond = self._infer_cond_uncond(inputs, infer_condition=True)
noise_pred_uncond = self._infer_cond_uncond(inputs, infer_condition=False)
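# Classifier-free guidance blend: start from the unconditional prediction and push it toward the conditional one by sample_guide_scale.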
self.scheduler.noise_pred = noise_pred_uncond + self.scheduler.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
else:
# ==================== No CFG ====================
self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.transformer_weights.non_block_weights_to_cpu()
@torch.no_grad()
def _infer_cond_uncond(self, inputs, infer_condition=True):
self.scheduler.infer_condition = infer_condition
pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
if self.config["seq_parallel"]:
pre_infer_out = self._seq_parallel_pre_process(pre_infer_out)
x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
if self.config["seq_parallel"]:
x = self._seq_parallel_post_process(x)
noise_pred = self.post_infer.infer(x, pre_infer_out)[0]
return noise_pred
@torch.no_grad()
def _seq_parallel_pre_process(self, pre_infer_out):
seqlen = pre_infer_out.img.shape[1]
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
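# Right-pad the token dimension so it divides evenly across ranks, e.g. seqlen=10 with world_size=4 gives padding_size=2.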
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
pre_infer_out.img = F.pad(pre_infer_out.img, (0, 0, 0, padding_size))
pre_infer_out.img = torch.chunk(pre_infer_out.img, world_size, dim=1)[cur_rank]
return pre_infer_out
@torch.no_grad()
def _seq_parallel_post_process(self, x):
world_size = dist.get_world_size(self.seq_p_group)
gathered_x = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(gathered_x, x, group=self.seq_p_group)
combined_output = torch.cat(gathered_x, dim=1)
return combined_output
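# Orientation note (non-exhaustive): configuration keys read by HunyuanVideo15Model above include
#   seq_parallel / cfg_parallel / device_mesh       - distributed groups and layout
#   cpu_offload / offload_granularity               - "model" vs. "block" offloading
#   dit_quantized / dit_quant_scheme                - quantized checkpoint schemes
#   feature_caching                                 - "NoCaching" | "Mag" | "Tea"
#   dit_original_ckpt / dit_quantized_ckpt / transformer_model_path - weight sources
#   enable_cfg                                      - toggles classifier-free guidance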
from lightx2v.common.modules.weight_module import WeightModule
class HunyuanVideo15PostWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import (
CONV3D_WEIGHT_REGISTER,
EMBEDDING_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
)
class HunyuanVideo15PreWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.mm_type = config.get("dit_quant_scheme", "Default")
self.patch_size = config["patch_size"] # (1, 1, 1)
self.add_module(
"img_in",
CONV3D_WEIGHT_REGISTER["Default"](
"img_in.proj.weight",
"img_in.proj.bias",
stride=self.patch_size,
),
)
self.add_module(
"time_in_0",
MM_WEIGHT_REGISTER["Default"](
"time_in.mlp.0.weight",
"time_in.mlp.0.bias",
),
)
self.add_module(
"time_in_2",
MM_WEIGHT_REGISTER["Default"](
"time_in.mlp.2.weight",
"time_in.mlp.2.bias",
),
)
if self.config["is_sr_running"]:
self.add_module(
"time_r_in_0",
MM_WEIGHT_REGISTER["Default"](
"time_r_in.mlp.0.weight",
"time_r_in.mlp.0.bias",
),
)
self.add_module(
"time_r_in_2",
MM_WEIGHT_REGISTER["Default"](
"time_r_in.mlp.2.weight",
"time_r_in.mlp.2.bias",
),
)
self.add_module(
"txt_in_t_embedder_0",
MM_WEIGHT_REGISTER["Default"](
"txt_in.t_embedder.mlp.0.weight",
"txt_in.t_embedder.mlp.0.bias",
),
)
self.add_module(
"txt_in_t_embedder_2",
MM_WEIGHT_REGISTER["Default"](
"txt_in.t_embedder.mlp.2.weight",
"txt_in.t_embedder.mlp.2.bias",
),
)
self.add_module(
"txt_in_c_embedder_0",
MM_WEIGHT_REGISTER["Default"](
"txt_in.c_embedder.linear_1.weight",
"txt_in.c_embedder.linear_1.bias",
),
)
self.add_module(
"txt_in_c_embedder_2",
MM_WEIGHT_REGISTER["Default"](
"txt_in.c_embedder.linear_2.weight",
"txt_in.c_embedder.linear_2.bias",
),
)
self.add_module(
"txt_in_input_embedder",
MM_WEIGHT_REGISTER["Default"](
"txt_in.input_embedder.weight",
"txt_in.input_embedder.bias",
),
)
self.add_module(
"individual_token_refiner",
WeightModuleList(
[
IndividualTokenRefinerBlock(
i,
self.mm_type,
self.config,
"txt_in.individual_token_refiner.blocks",
)
for i in range(2) # 2 blocks
]
),
)
self.add_module(
"cond_type_embedding",
EMBEDDING_WEIGHT_REGISTER["Default"](
"cond_type_embedding.weight",
),
)
class IndividualTokenRefinerBlock(WeightModule):
def __init__(self, block_idx, mm_type, config, block_prefix):
super().__init__()
self.config = config
self.mm_type = mm_type
self.add_module(
"norm1",
LN_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.norm1.weight", f"{block_prefix}.{block_idx}.norm1.bias"),
)
self.add_module(
"self_attn_qkv",
MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.self_attn_qkv.weight", f"{block_prefix}.{block_idx}.self_attn_qkv.bias"),
)
self.add_module(
"self_attn_proj",
MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.self_attn_proj.weight", f"{block_prefix}.{block_idx}.self_attn_proj.bias"),
)
self.add_module(
"norm2",
LN_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.norm2.weight", f"{block_prefix}.{block_idx}.norm2.bias"),
)
self.add_module(
"mlp_fc1",
MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.mlp.fc1.weight", f"{block_prefix}.{block_idx}.mlp.fc1.bias"),
)
self.add_module(
"mlp_fc2",
MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.mlp.fc2.weight", f"{block_prefix}.{block_idx}.mlp.fc2.bias"),
)
self.add_module(
"adaLN_modulation",
MM_WEIGHT_REGISTER["Default"](f"{block_prefix}.{block_idx}.adaLN_modulation.1.weight", f"{block_prefix}.{block_idx}.adaLN_modulation.1.bias"),
)
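# For reference, with block_prefix == "txt_in.individual_token_refiner.blocks" and
# block_idx == 0, the modules above bind checkpoint keys such as
#   txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight / .bias
#   txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight / .bias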
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import (
ATTN_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
RMS_WEIGHT_REGISTER,
)
class HunyuanVideo15TransformerWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.task = config["task"]
self.mm_type = config.get("dit_quant_scheme", "Default")
self.ln_type = config.get("ln_type", "Triton")
self.rms_type = config.get("rms_type", "sgl-kernel")
self.double_blocks_num = config["mm_double_blocks_depth"]
self.register_offload_buffers(config)
self.add_module("double_blocks", WeightModuleList([MMDoubleStreamBlock(i, self.task, self.config, block_prefix="double_blocks") for i in range(self.double_blocks_num)]))
self.add_module("final_layer", FinalLayerWeights(self.config))
def register_offload_buffers(self, config):
if config["cpu_offload"]:
if config.get("offload_granularity", "block") == "block":
self.offload_blocks_num = 2
self.offload_block_buffers = WeightModuleList(
[
MMDoubleStreamBlock(
i,
self.task,
self.config,
"double_blocks",
True,
)
for i in range(self.offload_blocks_num)
]
)
self.add_module("offload_block_buffers", self.offload_block_buffers)
self.offload_phase_buffers = None
def non_block_weights_to_cuda(self):
self.final_layer.to_cuda()
def non_block_weights_to_cpu(self):
self.final_layer.to_cpu()
class MMDoubleStreamBlock(WeightModule):
def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
super().__init__()
self.block_index = block_index
self.task = task
self.config = config
self.is_offload_buffer = is_offload_buffer
self.lazy_load = False
self.lazy_load_file = None
self.add_module(
"img_branch",
MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, is_offload_buffer),
)
self.add_module(
"txt_branch",
MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, is_offload_buffer),
)
attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["attn_type"]]
self.add_module("self_attention", attention_weights_cls())
if self.config["seq_parallel"]:
self.add_module(
"self_attention_parallel",
ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")](),
)
class MMDoubleStreamBlockImgBranch(WeightModule):
def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
super().__init__()
self.block_index = block_index
self.task = task
self.config = config
self.lazy_load = False
self.lazy_load_file = None
self.mm_type = config.get("dit_quant_scheme", "Default")
self.ln_type = config.get("ln_type", "Triton")
self.rms_type = config.get("rms_type", "sgl-kernel")
self.add_module(
"img_mod",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_mod.linear.weight",
f"{block_prefix}.{self.block_index}.img_mod.linear.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_norm1",
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_q.weight",
f"{block_prefix}.{self.block_index}.img_attn_q.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_k.weight",
f"{block_prefix}.{self.block_index}.img_attn_k.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_attn_v",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_v.weight",
f"{block_prefix}.{self.block_index}.img_attn_v.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_attn_q_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.img_attn_q_norm.weight",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_attn_k_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.img_attn_k_norm.weight",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_attn_proj",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_attn_proj.weight",
f"{block_prefix}.{self.block_index}.img_attn_proj.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_norm2",
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_mlp_fc1",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_mlp.fc1.weight",
f"{block_prefix}.{self.block_index}.img_mlp.fc1.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_mlp_fc2",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_mlp.fc2.weight",
f"{block_prefix}.{self.block_index}.img_mlp.fc2.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
class MMDoubleStreamBlockTxtBranch(WeightModule):
def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
super().__init__()
self.block_index = block_index
self.task = task
self.config = config
self.lazy_load = False
self.lazy_load_file = None
self.mm_type = config.get("dit_quant_scheme", "Default")
self.ln_type = config.get("ln_type", "Triton")
self.rms_type = config.get("rms_type", "sgl-kernel")
self.add_module(
"txt_mod",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_mod.linear.weight",
f"{block_prefix}.{self.block_index}.txt_mod.linear.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_norm1",
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_q.weight",
f"{block_prefix}.{self.block_index}.txt_attn_q.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_k.weight",
f"{block_prefix}.{self.block_index}.txt_attn_k.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_attn_v",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_v.weight",
f"{block_prefix}.{self.block_index}.txt_attn_v.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_attn_q_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.txt_attn_q_norm.weight",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_attn_k_norm",
RMS_WEIGHT_REGISTER[self.rms_type](
f"{block_prefix}.{self.block_index}.txt_attn_k_norm.weight",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_attn_proj",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_attn_proj.weight",
f"{block_prefix}.{self.block_index}.txt_attn_proj.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_norm2",
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_mlp_fc1",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_mlp.fc1.weight",
f"{block_prefix}.{self.block_index}.txt_mlp.fc1.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_mlp_fc2",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_mlp.fc2.weight",
f"{block_prefix}.{self.block_index}.txt_mlp.fc2.bias",
is_offload_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
class FinalLayerWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.config = config
self.lazy_load = False
self.lazy_load_file = None
self.mm_type = config.get("dit_quant_scheme", "Default")
self.ln_type = config.get("ln_type", "Triton")
self.add_module(
"adaLN_modulation",
MM_WEIGHT_REGISTER["Default"](
"final_layer.adaLN_modulation.1.weight",
"final_layer.adaLN_modulation.1.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"linear",
MM_WEIGHT_REGISTER["Default"](
"final_layer.linear.weight",
"final_layer.linear.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"norm_final",
LN_WEIGHT_REGISTER[self.ln_type](
None,
None,
self.lazy_load,
self.lazy_load_file,
),
)
......@@ -22,8 +22,9 @@ class WanAudioModel(WanModel):
def __init__(self, model_path, config, device):
self.config = config
super().__init__(model_path, config, device)
self._load_adapter_ckpt()
self.run_device = self.config.get("run_device", "cuda")
super().__init__(model_path, config, device)
def _load_adapter_ckpt(self):
if self.config.get("adapter_model_path", None) is None:
......@@ -50,7 +51,7 @@ class WanAudioModel(WanModel):
if not adapter_offload:
if not dist.is_initialized() or not load_from_rank0:
for key in self.adapter_weights_dict:
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.device))
self.adapter_weights_dict[key] = self.adapter_weights_dict[key].to(torch.device(self.run_device))
def _init_infer_class(self):
super()._init_infer_class()
......
......@@ -21,7 +21,7 @@ class WanAudioPreInfer(WanPreInfer):
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).to(self.device)
).to(torch.device(self.run_device))
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.rope_t_dim = d // 2 - 2 * (d // 6)
......