Commit 91c5dd15
authored Apr 07, 2025 by gushiqiao, committed by Yang Yong(雍洋) on Apr 08, 2025

Support cpu offload for hunyuan
parent 2b0139fe
Showing 3 changed files with 17 additions and 50 deletions
lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py  +16 -48
lightx2v/text2v/models/networks/hunyuan/model.py  +1 -1
lightx2v/text2v/models/networks/hunyuan/weights/transformer_weights.py  +0 -1
lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py (View file @ 91c5dd15)
@@ -28,17 +28,13 @@ class HunyuanTransformerInfer:
    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)

    def _infer_with_offload(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        txt_seq_len = txt.shape[0]
        img_seq_len = img.shape[0]
        for double_block_idx in range(self.double_blocks_num):
            if double_block_idx == 0:
                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
                self.double_weights_stream_mgr.active_weights[0].to_cuda()
            with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
@@ -53,18 +49,14 @@ class HunyuanTransformerInfer:
                )
            if double_block_idx < self.double_blocks_num - 1:
                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
            self.double_weights_stream_mgr.swap_weights()

        x = torch.cat((img, txt), 0)
        for single_block_idx in range(self.single_blocks_num):
            if single_block_idx == 0:
                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
                self.single_weights_stream_mgr.active_weights[0].to_cuda()
            with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
                x = self.infer_single_block(
@@ -77,9 +69,7 @@ class HunyuanTransformerInfer:
                    freqs_cis,
                )
            if single_block_idx < self.single_blocks_num - 1:
                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
            self.single_weights_stream_mgr.swap_weights()

        img = x[:img_seq_len, ...]
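The two loops above implement the offload scheme: only one transformer block's weights live on the GPU at a time, the next block's weights are prefetched while the current block computes, and swap_weights() rotates the buffers. The stream-manager class itself is not part of this diff; the following is a minimal sketch of the interface _infer_with_offload relies on, assuming a two-slot buffer and a dedicated copy stream (the slot layout and synchronization details are guesses, not the repo's implementation):

import torch

class WeightStreamMgrSketch:
    # Hypothetical stand-in for the manager used above (not from this commit).
    def __init__(self):
        self.active_weights = [None, None]       # [in-use slot, prefetch slot]
        self.compute_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

    def prefetch_weights(self, block_idx, blocks_weights):
        # Issue the host-to-device copy for the next block on the copy
        # stream so it overlaps with compute on compute_stream.
        with torch.cuda.stream(self.load_stream):
            self.active_weights[1] = blocks_weights[block_idx]
            self.active_weights[1].to_cuda()

    def swap_weights(self):
        # Wait for both streams, then promote the prefetched slot. A real
        # implementation would also return the retired weights to the CPU.
        self.compute_stream.synchronize()
        self.load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = self.active_weights[1], self.active_weights[0]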
@@ -116,9 +106,7 @@ class HunyuanTransformerInfer:
        img = x[:img_seq_len, ...]
        return img, vec

    def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        vec_silu = torch.nn.functional.silu(vec)
        img_mod_out = weights.img_mod.apply(vec_silu)
@@ -141,12 +129,8 @@ class HunyuanTransformerInfer:
            txt_mod2_gate,
        ) = txt_mod_out.chunk(6, dim=-1)

        img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
        txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)

        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)
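img_mod and txt_mod here are the adaLN-style modulation used in diffusion transformers: the conditioning vector goes through SiLU and a projection, and the result is chunked into shift/scale/gate terms (six per stream in the double blocks, covering both attention and MLP). A self-contained sketch of one shift/scale/gate application, matching the layer_norm and (1 + scale) form used in the pre-attention helpers below (all names and sizes are illustrative, not from the diff):

import torch

hidden = 64
mod = torch.nn.Linear(hidden, 3 * hidden)    # stand-in for the modulation projection
branch = torch.nn.Linear(hidden, hidden)     # stand-in for attention or MLP
x = torch.randn(10, hidden)                  # (seq_len, hidden) tokens
vec = torch.randn(hidden)                    # conditioning vector

shift, scale, gate = mod(torch.nn.functional.silu(vec)).chunk(3, dim=-1)
x_norm = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
x_mod = x_norm * (1 + scale) + shift         # modulated input, as in the hunks above
out = x + gate * branch(x_mod)               # gated residual update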
@@ -197,18 +181,12 @@ class HunyuanTransformerInfer:
        )
        return img, txt

    def infer_double_block_img_pre_atten(self, weights, img, img_mod1_scale, img_mod1_shift, freqs_cis):
        img_modulated = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
        img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
        img_qkv = weights.img_attn_qkv.apply(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        img_q = weights.img_attn_q_norm.apply(img_q)
        img_k = weights.img_attn_k_norm.apply(img_k)
@@ -216,18 +194,12 @@ class HunyuanTransformerInfer:
        img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis)
        return img_q, img_k, img_v

    def infer_double_block_txt_pre_atten(self, weights, txt, txt_mod1_scale, txt_mod1_shift):
        txt_modulated = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
        txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
        txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
        txt_q = weights.txt_attn_q_norm.apply(txt_q)
        txt_k = weights.txt_attn_k_norm.apply(txt_k)
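Both pre-attention helpers unpack a fused QKV projection with a single einops rearrange: the last axis is read as (K=3, heads, head_dim) and K is hoisted to the front, so the result unpacks directly into q, k, v. A standalone check of that pattern (sizes are illustrative):

import torch
from einops import rearrange

L, H, D = 16, 8, 4                       # seq len, heads, head dim
qkv = torch.randn(L, 3 * H * D)          # fused projection output
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=H)
assert q.shape == (L, H, D)              # per-head layout ready for attention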
@@ -279,9 +251,7 @@ class HunyuanTransformerInfer:
        txt = txt + out
        return txt

    def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
        out = torch.nn.functional.silu(vec)
        out = weights.modulation.apply(out)
        mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
@@ -291,9 +261,7 @@ class HunyuanTransformerInfer:
        x_mod = weights.linear1.apply(x_mod)
        qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
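In the single blocks, linear1 fuses the QKV projection and the MLP up-projection into one matmul, and torch.split then carves the output into a 3 * hidden_size QKV slab and an mlp_hidden_dim slab. A small sketch of the same split (dimensions illustrative):

import torch

hidden_size, mlp_hidden_dim, L = 128, 512, 16
linear1 = torch.nn.Linear(hidden_size, 3 * hidden_size + mlp_hidden_dim)

x_mod = torch.randn(L, hidden_size)
fused = linear1(x_mod)                   # one GEMM covers QKV and MLP input
qkv, mlp = torch.split(fused, [3 * hidden_size, mlp_hidden_dim], dim=-1)
assert qkv.shape == (L, 3 * hidden_size) and mlp.shape == (L, mlp_hidden_dim)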
lightx2v/text2v/models/networks/hunyuan/model.py (View file @ 91c5dd15)
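The model.py hunk is not expanded on this page, but given the infer -> self.infer_func indirection in the first hunk, the +1/-1 change plausibly selects the offload path from config. A hypothetical sketch of such wiring (everything except infer_func and _infer_with_offload is an assumption):

# Hypothetical dispatch, not the actual model.py change:
class HunyuanModelSketch:
    def __init__(self, config, transformer_infer):
        # Choose the blockwise-offload path when CPU offload is enabled.
        if config.get("cpu_offload", False):
            transformer_infer.infer_func = transformer_infer._infer_with_offload
        else:
            transformer_infer.infer_func = transformer_infer._infer_without_offload  # assumed name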
lightx2v/text2v/models/networks/hunyuan/weights/transformer_weights.py (View file @ 91c5dd15)
@@ -136,7 +136,6 @@ class HunyuanTransformerSingleBlock:
        mm_weight.set_config(self.config["mm_config"])
        mm_weight.load(weight_dict)

    def to_cpu(self):
        for mm_weight in self.weight_list:
            if isinstance(mm_weight, (MMWeightTemplate, RMSWeightTemplate)):
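to_cpu here is the counterpart of the to_cuda calls in the offload loops; the weight-template classes themselves are outside this diff. A minimal sketch of how such a template might move its tensors, assuming pinned host memory so host-to-device copies can run asynchronously (an illustration, not MMWeightTemplate's actual code):

import torch

class MMWeightSketch:
    # Illustrative stand-in for MMWeightTemplate (not the repo's class).
    def __init__(self, weight, bias=None):
        # Pinned (page-locked) host memory enables non-blocking H2D copies.
        self.weight = weight.pin_memory()
        self.bias = bias.pin_memory() if bias is not None else None

    def to_cuda(self):
        self.weight = self.weight.cuda(non_blocking=True)
        if self.bias is not None:
            self.bias = self.bias.cuda(non_blocking=True)

    def to_cpu(self):
        # Synchronous device-to-host move; re-pin for the next prefetch.
        self.weight = self.weight.cpu().pin_memory()
        if self.bias is not None:
            self.bias = self.bias.cpu().pin_memory()

    def apply(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias)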