Commit 9da774a7 authored by helloyongyang's avatar helloyongyang
Browse files

update hunyuan infer code

parent dcaefe63
...@@ -2,10 +2,11 @@ import torch ...@@ -2,10 +2,11 @@ import torch
from einops import rearrange from einops import rearrange
from .utils_bf16 import apply_rotary_emb from .utils_bf16 import apply_rotary_emb
from lightx2v.common.offload.manager import WeightAsyncStreamManager from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
class HunyuanTransformerInfer: class HunyuanTransformerInfer(BaseTransformerInfer):
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.attention_type = config.get("attention_type", "flash_attn2") self.attention_type = config.get("attention_type", "flash_attn2")
...@@ -26,9 +27,6 @@ class HunyuanTransformerInfer: ...@@ -26,9 +27,6 @@ class HunyuanTransformerInfer:
else: else:
self.infer_func = self._infer_without_offload self.infer_func = self._infer_without_offload
    def set_scheduler(self, scheduler):
        """Attach the scheduler instance used by this transformer during inference."""
        self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()) @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None): def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num) return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
...@@ -85,7 +83,7 @@ class HunyuanTransformerInfer: ...@@ -85,7 +83,7 @@ class HunyuanTransformerInfer:
img = x[:img_seq_len, ...] img = x[:img_seq_len, ...]
return img, vec return img, vec
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num): def infer_double_block_phase_1(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
vec_silu = torch.nn.functional.silu(vec) vec_silu = torch.nn.functional.silu(vec)
img_mod_out = weights.img_mod.apply(vec_silu) img_mod_out = weights.img_mod.apply(vec_silu)
...@@ -146,10 +144,136 @@ class HunyuanTransformerInfer: ...@@ -146,10 +144,136 @@ class HunyuanTransformerInfer:
) )
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :] img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(
weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num img_out = weights.img_attn_proj.apply(img_attn)
txt_out = weights.txt_attn_proj.apply(txt_attn)
return (
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
)
def infer_double_block_phase_2(
self,
weights,
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
):
if tr_img_mod1_gate is not None:
x_zero = img_out[:frist_frame_token_num] * tr_img_mod1_gate
x_orig = img_out[frist_frame_token_num:] * img_mod1_gate
img_out = torch.concat((x_zero, x_orig), dim=0)
else:
img_out = img_out * img_mod1_gate
img = img + img_out
img_out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
if tr_img_mod1_gate is not None:
x_zero = img_out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
x_orig = img_out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
img_out = torch.concat((x_zero, x_orig), dim=0)
else:
img_out = img_out * (1 + img_mod2_scale) + img_mod2_shift
img_out = weights.img_mlp_fc1.apply(img_out)
img_out = torch.nn.functional.gelu(img_out, approximate="tanh")
img_out = weights.img_mlp_fc2.apply(img_out)
txt_out = txt_out * txt_mod1_gate
txt = txt + txt_out
txt_out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
txt_out = txt_out * (1 + txt_mod2_scale) + txt_mod2_shift
txt_out = weights.txt_mlp_fc1.apply(txt_out)
txt_out = torch.nn.functional.gelu(txt_out, approximate="tanh")
txt_out = weights.txt_mlp_fc2.apply(txt_out)
return img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate
def infer_double_block_phase_3(self, img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt):
# img
img_out = img_out * img_mod2_gate
img = img + img_out
# txt
txt_out = txt_out * txt_mod2_gate
txt = txt + txt_out
return img, txt
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
(
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.infer_double_block_phase_1(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
weights,
img,
txt,
vec,
cu_seqlens_qkv,
max_seqlen_qkv,
freqs_cis,
token_replace_vec,
frist_frame_token_num,
img_out,
txt_out,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) )
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
return img, txt return img, txt
def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis): def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
...@@ -181,56 +305,7 @@ class HunyuanTransformerInfer: ...@@ -181,56 +305,7 @@ class HunyuanTransformerInfer:
txt_k = weights.txt_attn_k_norm.apply(txt_k) txt_k = weights.txt_attn_k_norm.apply(txt_k)
return txt_q, txt_k, txt_v return txt_q, txt_k, txt_v
def infer_double_block_img_post_atten( def infer_single_block_phase_1(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num
):
out = weights.img_attn_proj.apply(img_attn)
if tr_img_mod1_gate is not None:
x_zero = out[:frist_frame_token_num] * tr_img_mod1_gate
x_orig = out[frist_frame_token_num:] * img_mod1_gate
out = torch.concat((x_zero, x_orig), dim=0)
else:
out = out * img_mod1_gate
img = img + out
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
if tr_img_mod1_gate is not None:
x_zero = out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
x_orig = out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
out = torch.concat((x_zero, x_orig), dim=0)
else:
out = out * (1 + img_mod2_scale) + img_mod2_shift
out = weights.img_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate="tanh")
out = weights.img_mlp_fc2.apply(out)
out = out * img_mod2_gate
img = img + out
return img
def infer_double_block_txt_post_atten(
self,
weights,
txt,
txt_attn,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
):
out = weights.txt_attn_proj.apply(txt_attn)
out = out * txt_mod1_gate
txt = txt + out
out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
out = out * (1 + txt_mod2_scale) + txt_mod2_shift
out = weights.txt_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate="tanh")
out = weights.txt_mlp_fc2.apply(out)
out = out * txt_mod2_gate
txt = txt + out
return txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
out = torch.nn.functional.silu(vec) out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out) out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1) mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
...@@ -239,6 +314,8 @@ class HunyuanTransformerInfer: ...@@ -239,6 +314,8 @@ class HunyuanTransformerInfer:
token_replace_vec_out = torch.nn.functional.silu(token_replace_vec) token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
token_replace_vec_out = weights.modulation.apply(token_replace_vec_out) token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1) tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)
else:
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6) out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
if token_replace_vec is not None: if token_replace_vec is not None:
...@@ -289,7 +366,9 @@ class HunyuanTransformerInfer: ...@@ -289,7 +366,9 @@ class HunyuanTransformerInfer:
out = torch.nn.functional.gelu(mlp, approximate="tanh") out = torch.nn.functional.gelu(mlp, approximate="tanh")
out = torch.cat((attn, out), 1) out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out) out = weights.linear2.apply(out)
return out, mod_gate, tr_mod_gate
def infer_single_block_phase_2(self, x, out, tr_mod_gate, mod_gate, token_replace_vec=None, frist_frame_token_num=None):
if token_replace_vec is not None: if token_replace_vec is not None:
x_zero = out[:frist_frame_token_num] * tr_mod_gate x_zero = out[:frist_frame_token_num] * tr_mod_gate
x_orig = out[frist_frame_token_num:] * mod_gate x_orig = out[frist_frame_token_num:] * mod_gate
...@@ -298,3 +377,8 @@ class HunyuanTransformerInfer: ...@@ -298,3 +377,8 @@ class HunyuanTransformerInfer:
out = out * mod_gate out = out * mod_gate
x = x + out x = x + out
return x return x
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
return x
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment