xuwx1 / LightX2V · Commit 9da774a7

Authored Jun 30, 2025 by helloyongyang
Parent: dcaefe63

    update hunyuan infer code

Showing 1 changed file with 142 additions and 58 deletions:
lightx2v/models/networks/hunyuan/infer/transformer_infer.py (+142, -58)
@@ -2,10 +2,11 @@ import torch
 from einops import rearrange
 from .utils_bf16 import apply_rotary_emb
 from lightx2v.common.offload.manager import WeightAsyncStreamManager
+from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *


-class HunyuanTransformerInfer:
+class HunyuanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
         self.attention_type = config.get("attention_type", "flash_attn2")
@@ -26,9 +27,6 @@ class HunyuanTransformerInfer:
         else:
             self.infer_func = self._infer_without_offload

-    def set_scheduler(self, scheduler):
-        self.scheduler = scheduler
-
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
         return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
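The infer entry point stays eager unless graph mode is switched on, since torch.compile with disable=True is a no-op. Below is a minimal, self-contained sketch of that env-gated pattern; the real CHECK_ENABLE_GRAPH_MODE comes from lightx2v.utils.envs, so the helper body and the ENABLE_GRAPH_MODE variable name are illustrative assumptions only.

    import os
    import torch

    def CHECK_ENABLE_GRAPH_MODE():
        # Hypothetical stand-in for the helper imported from lightx2v.utils.envs.
        return os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"

    # disable=True turns torch.compile into a no-op, so the function runs
    # eagerly unless graph mode is enabled in the environment.
    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
    def fused_silu_gate(x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x) * x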
@@ -85,7 +83,7 @@ class HunyuanTransformerInfer:
         img = x[:img_seq_len, ...]
         return img, vec

-    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
+    def infer_double_block_phase_1(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
         vec_silu = torch.nn.functional.silu(vec)
         img_mod_out = weights.img_mod.apply(vec_silu)
@@ -146,10 +144,136 @@ class HunyuanTransformerInfer:
         )
         img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]

-        img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num)
-        txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
+        img_out = weights.img_attn_proj.apply(img_attn)
+        txt_out = weights.txt_attn_proj.apply(txt_attn)
+        return (
+            img_out,
+            txt_out,
+            img_mod1_gate,
+            img_mod2_shift,
+            img_mod2_scale,
+            img_mod2_gate,
+            tr_img_mod1_gate,
+            tr_img_mod2_shift,
+            tr_img_mod2_scale,
+            tr_img_mod2_gate,
+            txt_mod1_gate,
+            txt_mod2_shift,
+            txt_mod2_scale,
+            txt_mod2_gate,
+        )
+
+    def infer_double_block_phase_2(
+        self,
+        weights,
+        img,
+        txt,
+        vec,
+        cu_seqlens_qkv,
+        max_seqlen_qkv,
+        freqs_cis,
+        token_replace_vec,
+        frist_frame_token_num,
+        img_out,
+        txt_out,
+        img_mod1_gate,
+        img_mod2_shift,
+        img_mod2_scale,
+        img_mod2_gate,
+        tr_img_mod1_gate,
+        tr_img_mod2_shift,
+        tr_img_mod2_scale,
+        tr_img_mod2_gate,
+        txt_mod1_gate,
+        txt_mod2_shift,
+        txt_mod2_scale,
+        txt_mod2_gate,
+    ):
+        if tr_img_mod1_gate is not None:
+            x_zero = img_out[:frist_frame_token_num] * tr_img_mod1_gate
+            x_orig = img_out[frist_frame_token_num:] * img_mod1_gate
+            img_out = torch.concat((x_zero, x_orig), dim=0)
+        else:
+            img_out = img_out * img_mod1_gate
+        img = img + img_out
+        img_out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
+        if tr_img_mod1_gate is not None:
+            x_zero = img_out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
+            x_orig = img_out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
+            img_out = torch.concat((x_zero, x_orig), dim=0)
+        else:
+            img_out = img_out * (1 + img_mod2_scale) + img_mod2_shift
+        img_out = weights.img_mlp_fc1.apply(img_out)
+        img_out = torch.nn.functional.gelu(img_out, approximate="tanh")
+        img_out = weights.img_mlp_fc2.apply(img_out)
+
+        txt_out = txt_out * txt_mod1_gate
+        txt = txt + txt_out
+        txt_out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
+        txt_out = txt_out * (1 + txt_mod2_scale) + txt_mod2_shift
+        txt_out = weights.txt_mlp_fc1.apply(txt_out)
+        txt_out = torch.nn.functional.gelu(txt_out, approximate="tanh")
+        txt_out = weights.txt_mlp_fc2.apply(txt_out)
+        return img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate
+
+    def infer_double_block_phase_3(self, img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt):
+        # img
+        img_out = img_out * img_mod2_gate
+        img = img + img_out
+        # txt
+        txt_out = txt_out * txt_mod2_gate
+        txt = txt + txt_out
+        return img, txt
+
+    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num):
+        (
+            img_out,
+            txt_out,
+            img_mod1_gate,
+            img_mod2_shift,
+            img_mod2_scale,
+            img_mod2_gate,
+            tr_img_mod1_gate,
+            tr_img_mod2_shift,
+            tr_img_mod2_scale,
+            tr_img_mod2_gate,
+            txt_mod1_gate,
+            txt_mod2_shift,
+            txt_mod2_scale,
+            txt_mod2_gate,
+        ) = self.infer_double_block_phase_1(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
+        img, txt, img_out, txt_out, img_mod2_gate, txt_mod2_gate = self.infer_double_block_phase_2(
+            weights,
+            img,
+            txt,
+            vec,
+            cu_seqlens_qkv,
+            max_seqlen_qkv,
+            freqs_cis,
+            token_replace_vec,
+            frist_frame_token_num,
+            img_out,
+            txt_out,
+            img_mod1_gate,
+            img_mod2_shift,
+            img_mod2_scale,
+            img_mod2_gate,
+            tr_img_mod1_gate,
+            tr_img_mod2_shift,
+            tr_img_mod2_scale,
+            tr_img_mod2_gate,
+            txt_mod1_gate,
+            txt_mod2_shift,
+            txt_mod2_scale,
+            txt_mod2_gate,
+        )
+        img, txt = self.infer_double_block_phase_3(img_out, img_mod2_gate, img, txt_out, txt_mod2_gate, txt)
         return img, txt

     def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, tr_img_mod1_scale, tr_img_mod1_shift, frist_frame_token_num, freqs_cis):
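Both tr_img_mod1_gate branches in phase 2 repeat one pattern: when token replacement is active, the first frist_frame_token_num tokens (the replaced first frame) are modulated with their own gate or scale/shift while the remaining tokens use the regular parameters, and the two slices are then concatenated back. A minimal sketch of that split under the same 2-D (token, channel) layout; the helper name is illustrative:

    import torch

    def gate_with_token_replace(out, n_first, tr_gate, gate):
        # First-frame tokens use the token-replace gate; the rest use the
        # regular gate. With tr_gate=None this reduces to plain gating,
        # matching the else branches in the diff.
        if tr_gate is not None:
            return torch.concat((out[:n_first] * tr_gate, out[n_first:] * gate), dim=0)
        return out * gate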
@@ -181,56 +305,7 @@ class HunyuanTransformerInfer:
         txt_k = weights.txt_attn_k_norm.apply(txt_k)
         return txt_q, txt_k, txt_v

-    def infer_double_block_img_post_atten(
-        self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate, tr_img_mod1_gate, tr_img_mod2_shift, tr_img_mod2_scale, tr_img_mod2_gate, frist_frame_token_num):
-        out = weights.img_attn_proj.apply(img_attn)
-        if tr_img_mod1_gate is not None:
-            x_zero = out[:frist_frame_token_num] * tr_img_mod1_gate
-            x_orig = out[frist_frame_token_num:] * img_mod1_gate
-            out = torch.concat((x_zero, x_orig), dim=0)
-        else:
-            out = out * img_mod1_gate
-        img = img + out
-        out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
-        if tr_img_mod1_gate is not None:
-            x_zero = out[:frist_frame_token_num] * (1 + tr_img_mod2_scale) + tr_img_mod2_shift
-            x_orig = out[frist_frame_token_num:] * (1 + img_mod2_scale) + img_mod2_shift
-            out = torch.concat((x_zero, x_orig), dim=0)
-        else:
-            out = out * (1 + img_mod2_scale) + img_mod2_shift
-        out = weights.img_mlp_fc1.apply(out)
-        out = torch.nn.functional.gelu(out, approximate="tanh")
-        out = weights.img_mlp_fc2.apply(out)
-        out = out * img_mod2_gate
-        img = img + out
-        return img
-
-    def infer_double_block_txt_post_atten(
-        self,
-        weights,
-        txt,
-        txt_attn,
-        txt_mod1_gate,
-        txt_mod2_shift,
-        txt_mod2_scale,
-        txt_mod2_gate,
-    ):
-        out = weights.txt_attn_proj.apply(txt_attn)
-        out = out * txt_mod1_gate
-        txt = txt + out
-        out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
-        out = out * (1 + txt_mod2_scale) + txt_mod2_shift
-        out = weights.txt_mlp_fc1.apply(out)
-        out = torch.nn.functional.gelu(out, approximate="tanh")
-        out = weights.txt_mlp_fc2.apply(out)
-        out = out * txt_mod2_gate
-        txt = txt + out
-        return txt
-
-    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
+    def infer_single_block_phase_1(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
         out = torch.nn.functional.silu(vec)
         out = weights.modulation.apply(out)
         mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
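For context, the unchanged lines at the top of this hunk apply a learned norm to the text keys (txt_attn_k_norm) before attention. That is the QK-norm idea: normalizing queries and keys bounds the attention logits. The sketch below uses RMSNorm, which is only an assumption; the diff does not show which norm the weights object actually implements.

    import torch

    def rms_norm(t, eps=1e-6):
        # Root-mean-square norm over the head dimension.
        return t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + eps)

    q = torch.randn(16, 8, 64)  # (tokens, heads, head_dim)
    k = torch.randn(16, 8, 64)
    q, k = rms_norm(q), rms_norm(k)  # keeps q @ k.T logits well-scaled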
@@ -239,6 +314,8 @@ class HunyuanTransformerInfer:
             token_replace_vec_out = torch.nn.functional.silu(token_replace_vec)
             token_replace_vec_out = weights.modulation.apply(token_replace_vec_out)
             tr_mod_shift, tr_mod_scale, tr_mod_gate = token_replace_vec_out.chunk(3, dim=-1)
+        else:
+            tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
         out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         if token_replace_vec is not None:
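The shift/scale/gate triple chunked out of the modulation projection drives AdaLN-style conditioning, and the token-replace branch simply produces a second triple for the first frame. Condensed into one helper (names illustrative, 2-D layout assumed), the per-block sequence used throughout this file looks like:

    import torch

    def adaln_block(x, shift, scale, gate, mlp):
        # Normalize without affine params, condition on (shift, scale),
        # then apply the learned gate to the residual update.
        h = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
        h = h * (1 + scale) + shift
        h = mlp(h)            # fc1 -> gelu(approximate="tanh") -> fc2 in the diff
        return x + gate * h   # gated residual add

    # e.g. adaln_block(torch.randn(16, 8), torch.zeros(8), torch.zeros(8),
    #                  torch.ones(8), torch.nn.Identity())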
@@ -289,7 +366,9 @@ class HunyuanTransformerInfer:
         out = torch.nn.functional.gelu(mlp, approximate="tanh")
         out = torch.cat((attn, out), 1)
         out = weights.linear2.apply(out)
+        return out, mod_gate, tr_mod_gate
+    def infer_single_block_phase_2(self, x, out, tr_mod_gate, mod_gate, token_replace_vec=None, frist_frame_token_num=None):
         if token_replace_vec is not None:
             x_zero = out[:frist_frame_token_num] * tr_mod_gate
             x_orig = out[frist_frame_token_num:] * mod_gate
@@ -298,3 +377,8 @@ class HunyuanTransformerInfer:
             out = out * mod_gate
         x = x + out
         return x
+
+    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
+        out, mod_gate, tr_mod_gate = self.infer_single_block_phase_1(weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
+        x = self.infer_single_block_phase_2(x, out, tr_mod_gate, mod_gate, token_replace_vec, frist_frame_token_num)
+        return x
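infer_single_block is now a thin wrapper over phase_1/phase_2, mirroring the double-block split above. A plausible motivation, given the WeightAsyncStreamManager import at the top of the diff, is that smaller phases give an offload manager natural points to prefetch upcoming weights on a second CUDA stream while the current phase computes. The sketch below shows only that generic two-stream shape; prefetch_weights is a hypothetical helper, not this repo's API.

    import torch

    compute_stream = torch.cuda.default_stream()
    load_stream = torch.cuda.Stream()

    def run_phases(phases, x):
        # While phase i computes on the default stream, phase i+1's weights
        # can be fetched on load_stream; wait_stream keeps the two ordered.
        for i, phase in enumerate(phases):
            if i + 1 < len(phases):
                with torch.cuda.stream(load_stream):
                    phases[i + 1].prefetch_weights()  # hypothetical helper
            x = phase(x)
            compute_stream.wait_stream(load_stream)
        return x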