Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
2b0139fe
"...resnet50_tensorflow.git" did not exist on "ff17f5ddd812357f25440991013f7152f992b9d4"
Commit
2b0139fe
authored
Apr 07, 2025
by
gushiqiao
Committed by
Yang Yong(雍洋)
Apr 08, 2025
Browse files
Support cpu offload for hunyuan
parent
83c5f3b8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
193 additions
and
29 deletions
+193
-29
lightx2v/__main__.py
lightx2v/__main__.py
+1
-1
lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py
...text2v/models/networks/hunyuan/infer/transformer_infer.py
+156
-17
lightx2v/text2v/models/networks/hunyuan/model.py
lightx2v/text2v/models/networks/hunyuan/model.py
+9
-2
lightx2v/text2v/models/networks/hunyuan/weights/transformer_weights.py
...2v/models/networks/hunyuan/weights/transformer_weights.py
+27
-6
lightx2v/text2v/models/networks/wan/weights/transformer_weights.py
...text2v/models/networks/wan/weights/transformer_weights.py
+0
-3
No files found.
lightx2v/__main__.py
View file @
2b0139fe
...
@@ -41,7 +41,7 @@ def load_models(args, model_config):
...
@@ -41,7 +41,7 @@ def load_models(args, model_config):
text_encoder_1
=
TextEncoderHFLlamaModel
(
os
.
path
.
join
(
args
.
model_path
,
"text_encoder"
),
init_device
)
text_encoder_1
=
TextEncoderHFLlamaModel
(
os
.
path
.
join
(
args
.
model_path
,
"text_encoder"
),
init_device
)
text_encoder_2
=
TextEncoderHFClipModel
(
os
.
path
.
join
(
args
.
model_path
,
"text_encoder_2"
),
init_device
)
text_encoder_2
=
TextEncoderHFClipModel
(
os
.
path
.
join
(
args
.
model_path
,
"text_encoder_2"
),
init_device
)
text_encoders
=
[
text_encoder_1
,
text_encoder_2
]
text_encoders
=
[
text_encoder_1
,
text_encoder_2
]
model
=
HunyuanModel
(
args
.
model_path
,
model_config
)
model
=
HunyuanModel
(
args
.
model_path
,
model_config
,
device
=
init_device
)
vae_model
=
VideoEncoderKLCausal3DModel
(
args
.
model_path
,
dtype
=
torch
.
float16
,
device
=
init_device
)
vae_model
=
VideoEncoderKLCausal3DModel
(
args
.
model_path
,
dtype
=
torch
.
float16
,
device
=
init_device
)
elif
args
.
model_cls
==
"wan2.1"
:
elif
args
.
model_cls
==
"wan2.1"
:
...
...
lightx2v/text2v/models/networks/hunyuan/infer/transformer_infer.py
View file @
2b0139fe
...
@@ -2,6 +2,7 @@ import torch
...
@@ -2,6 +2,7 @@ import torch
from
einops
import
rearrange
from
einops
import
rearrange
from
lightx2v.attentions
import
attention
from
lightx2v.attentions
import
attention
from
.utils_bf16
import
apply_rotary_emb
from
.utils_bf16
import
apply_rotary_emb
from
lightx2v.common.offload.manager
import
WeightStreamManager
class
HunyuanTransformerInfer
:
class
HunyuanTransformerInfer
:
...
@@ -14,26 +15,110 @@ class HunyuanTransformerInfer:
...
@@ -14,26 +15,110 @@ class HunyuanTransformerInfer:
self
.
hidden_size
=
3072
self
.
hidden_size
=
3072
self
.
mlp_hidden_dim
=
12288
self
.
mlp_hidden_dim
=
12288
self
.
parallel_attention
=
None
self
.
parallel_attention
=
None
if
self
.
config
[
"cpu_offload"
]:
self
.
double_weights_stream_mgr
=
WeightStreamManager
()
self
.
single_weights_stream_mgr
=
WeightStreamManager
()
self
.
infer_func
=
self
.
_infer_with_offload
else
:
self
.
infer_func
=
self
.
_infer_without_offload
def
set_scheduler
(
self
,
scheduler
):
def
set_scheduler
(
self
,
scheduler
):
self
.
scheduler
=
scheduler
self
.
scheduler
=
scheduler
def
infer
(
self
,
weights
,
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
def
infer
(
self
,
weights
,
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
return
self
.
infer_func
(
weights
,
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
)
def
_infer_with_offload
(
self
,
weights
,
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
txt_seq_len
=
txt
.
shape
[
0
]
img_seq_len
=
img
.
shape
[
0
]
for
double_block_idx
in
range
(
self
.
double_blocks_num
):
if
double_block_idx
==
0
:
self
.
double_weights_stream_mgr
.
active_weights
[
0
]
=
weights
.
double_blocks_weights
[
0
]
self
.
double_weights_stream_mgr
.
active_weights
[
0
].
to_cuda
()
with
torch
.
cuda
.
stream
(
self
.
double_weights_stream_mgr
.
compute_stream
):
img
,
txt
=
self
.
infer_double_block
(
self
.
double_weights_stream_mgr
.
active_weights
[
0
],
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
,
)
if
double_block_idx
<
self
.
double_blocks_num
-
1
:
self
.
double_weights_stream_mgr
.
prefetch_weights
(
double_block_idx
+
1
,
weights
.
double_blocks_weights
)
self
.
double_weights_stream_mgr
.
swap_weights
()
x
=
torch
.
cat
((
img
,
txt
),
0
)
for
single_block_idx
in
range
(
self
.
single_blocks_num
):
if
single_block_idx
==
0
:
self
.
single_weights_stream_mgr
.
active_weights
[
0
]
=
weights
.
single_blocks_weights
[
0
]
self
.
single_weights_stream_mgr
.
active_weights
[
0
].
to_cuda
()
with
torch
.
cuda
.
stream
(
self
.
single_weights_stream_mgr
.
compute_stream
):
x
=
self
.
infer_single_block
(
weights
.
single_blocks_weights
[
single_block_idx
],
x
,
vec
,
txt_seq_len
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
,
)
if
single_block_idx
<
self
.
single_blocks_num
-
1
:
self
.
single_weights_stream_mgr
.
prefetch_weights
(
single_block_idx
+
1
,
weights
.
single_blocks_weights
)
self
.
single_weights_stream_mgr
.
swap_weights
()
img
=
x
[:
img_seq_len
,
...]
return
img
,
vec
def
_infer_without_offload
(
self
,
weights
,
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
txt_seq_len
=
txt
.
shape
[
0
]
txt_seq_len
=
txt
.
shape
[
0
]
img_seq_len
=
img
.
shape
[
0
]
img_seq_len
=
img
.
shape
[
0
]
for
i
in
range
(
self
.
double_blocks_num
):
for
i
in
range
(
self
.
double_blocks_num
):
img
,
txt
=
self
.
infer_double_block
(
weights
.
double_blocks_weights
[
i
],
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
)
img
,
txt
=
self
.
infer_double_block
(
weights
.
double_blocks_weights
[
i
],
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
,
)
x
=
torch
.
cat
((
img
,
txt
),
0
)
x
=
torch
.
cat
((
img
,
txt
),
0
)
for
i
in
range
(
self
.
single_blocks_num
):
for
i
in
range
(
self
.
single_blocks_num
):
x
=
self
.
infer_single_block
(
weights
.
single_blocks_weights
[
i
],
x
,
vec
,
txt_seq_len
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
)
x
=
self
.
infer_single_block
(
weights
.
single_blocks_weights
[
i
],
x
,
vec
,
txt_seq_len
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
,
)
img
=
x
[:
img_seq_len
,
...]
img
=
x
[:
img_seq_len
,
...]
return
img
,
vec
return
img
,
vec
def
infer_double_block
(
self
,
weights
,
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
def
infer_double_block
(
self
,
weights
,
img
,
txt
,
vec
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
vec_silu
=
torch
.
nn
.
functional
.
silu
(
vec
)
vec_silu
=
torch
.
nn
.
functional
.
silu
(
vec
)
img_mod_out
=
weights
.
img_mod
.
apply
(
vec_silu
)
img_mod_out
=
weights
.
img_mod
.
apply
(
vec_silu
)
...
@@ -56,8 +141,12 @@ class HunyuanTransformerInfer:
...
@@ -56,8 +141,12 @@ class HunyuanTransformerInfer:
txt_mod2_gate
,
txt_mod2_gate
,
)
=
txt_mod_out
.
chunk
(
6
,
dim
=-
1
)
)
=
txt_mod_out
.
chunk
(
6
,
dim
=-
1
)
img_q
,
img_k
,
img_v
=
self
.
infer_double_block_img_pre_atten
(
weights
,
img
,
img_mod1_scale
,
img_mod1_shift
,
freqs_cis
)
img_q
,
img_k
,
img_v
=
self
.
infer_double_block_img_pre_atten
(
txt_q
,
txt_k
,
txt_v
=
self
.
infer_double_block_txt_pre_atten
(
weights
,
txt
,
txt_mod1_scale
,
txt_mod1_shift
)
weights
,
img
,
img_mod1_scale
,
img_mod1_shift
,
freqs_cis
)
txt_q
,
txt_k
,
txt_v
=
self
.
infer_double_block_txt_pre_atten
(
weights
,
txt
,
txt_mod1_scale
,
txt_mod1_shift
)
q
=
torch
.
cat
((
img_q
,
txt_q
),
dim
=
0
)
q
=
torch
.
cat
((
img_q
,
txt_q
),
dim
=
0
)
k
=
torch
.
cat
((
img_k
,
txt_k
),
dim
=
0
)
k
=
torch
.
cat
((
img_k
,
txt_k
),
dim
=
0
)
...
@@ -88,16 +177,38 @@ class HunyuanTransformerInfer:
...
@@ -88,16 +177,38 @@ class HunyuanTransformerInfer:
)
)
img_attn
,
txt_attn
=
attn
[:
img
.
shape
[
0
]],
attn
[
img
.
shape
[
0
]
:]
img_attn
,
txt_attn
=
attn
[:
img
.
shape
[
0
]],
attn
[
img
.
shape
[
0
]
:]
img
=
self
.
infer_double_block_img_post_atten
(
weights
,
img
,
img_attn
,
img_mod1_gate
,
img_mod2_shift
,
img_mod2_scale
,
img_mod2_gate
)
img
=
self
.
infer_double_block_img_post_atten
(
txt
=
self
.
infer_double_block_txt_post_atten
(
weights
,
txt
,
txt_attn
,
txt_mod1_gate
,
txt_mod2_shift
,
txt_mod2_scale
,
txt_mod2_gate
)
weights
,
img
,
img_attn
,
img_mod1_gate
,
img_mod2_shift
,
img_mod2_scale
,
img_mod2_gate
,
)
txt
=
self
.
infer_double_block_txt_post_atten
(
weights
,
txt
,
txt_attn
,
txt_mod1_gate
,
txt_mod2_shift
,
txt_mod2_scale
,
txt_mod2_gate
,
)
return
img
,
txt
return
img
,
txt
def
infer_double_block_img_pre_atten
(
self
,
weights
,
img
,
img_mod1_scale
,
img_mod1_shift
,
freqs_cis
):
def
infer_double_block_img_pre_atten
(
img_modulated
=
torch
.
nn
.
functional
.
layer_norm
(
img
,
(
img
.
shape
[
1
],),
None
,
None
,
1e-6
)
self
,
weights
,
img
,
img_mod1_scale
,
img_mod1_shift
,
freqs_cis
):
img_modulated
=
torch
.
nn
.
functional
.
layer_norm
(
img
,
(
img
.
shape
[
1
],),
None
,
None
,
1e-6
)
img_modulated
=
img_modulated
*
(
1
+
img_mod1_scale
)
+
img_mod1_shift
img_modulated
=
img_modulated
*
(
1
+
img_mod1_scale
)
+
img_mod1_shift
img_qkv
=
weights
.
img_attn_qkv
.
apply
(
img_modulated
)
img_qkv
=
weights
.
img_attn_qkv
.
apply
(
img_modulated
)
img_q
,
img_k
,
img_v
=
rearrange
(
img_qkv
,
"L (K H D) -> K L H D"
,
K
=
3
,
H
=
self
.
heads_num
)
img_q
,
img_k
,
img_v
=
rearrange
(
img_qkv
,
"L (K H D) -> K L H D"
,
K
=
3
,
H
=
self
.
heads_num
)
img_q
=
weights
.
img_attn_q_norm
.
apply
(
img_q
)
img_q
=
weights
.
img_attn_q_norm
.
apply
(
img_q
)
img_k
=
weights
.
img_attn_k_norm
.
apply
(
img_k
)
img_k
=
weights
.
img_attn_k_norm
.
apply
(
img_k
)
...
@@ -105,18 +216,33 @@ class HunyuanTransformerInfer:
...
@@ -105,18 +216,33 @@ class HunyuanTransformerInfer:
img_q
,
img_k
=
apply_rotary_emb
(
img_q
,
img_k
,
freqs_cis
)
img_q
,
img_k
=
apply_rotary_emb
(
img_q
,
img_k
,
freqs_cis
)
return
img_q
,
img_k
,
img_v
return
img_q
,
img_k
,
img_v
def
infer_double_block_txt_pre_atten
(
self
,
weights
,
txt
,
txt_mod1_scale
,
txt_mod1_shift
):
def
infer_double_block_txt_pre_atten
(
txt_modulated
=
torch
.
nn
.
functional
.
layer_norm
(
txt
,
(
txt
.
shape
[
1
],),
None
,
None
,
1e-6
)
self
,
weights
,
txt
,
txt_mod1_scale
,
txt_mod1_shift
):
txt_modulated
=
torch
.
nn
.
functional
.
layer_norm
(
txt
,
(
txt
.
shape
[
1
],),
None
,
None
,
1e-6
)
txt_modulated
=
txt_modulated
*
(
1
+
txt_mod1_scale
)
+
txt_mod1_shift
txt_modulated
=
txt_modulated
*
(
1
+
txt_mod1_scale
)
+
txt_mod1_shift
txt_qkv
=
weights
.
txt_attn_qkv
.
apply
(
txt_modulated
)
txt_qkv
=
weights
.
txt_attn_qkv
.
apply
(
txt_modulated
)
txt_q
,
txt_k
,
txt_v
=
rearrange
(
txt_qkv
,
"L (K H D) -> K L H D"
,
K
=
3
,
H
=
self
.
heads_num
)
txt_q
,
txt_k
,
txt_v
=
rearrange
(
txt_qkv
,
"L (K H D) -> K L H D"
,
K
=
3
,
H
=
self
.
heads_num
)
txt_q
=
weights
.
txt_attn_q_norm
.
apply
(
txt_q
)
txt_q
=
weights
.
txt_attn_q_norm
.
apply
(
txt_q
)
txt_k
=
weights
.
txt_attn_k_norm
.
apply
(
txt_k
)
txt_k
=
weights
.
txt_attn_k_norm
.
apply
(
txt_k
)
return
txt_q
,
txt_k
,
txt_v
return
txt_q
,
txt_k
,
txt_v
def
infer_double_block_img_post_atten
(
self
,
weights
,
img
,
img_attn
,
img_mod1_gate
,
img_mod2_shift
,
img_mod2_scale
,
img_mod2_gate
):
def
infer_double_block_img_post_atten
(
self
,
weights
,
img
,
img_attn
,
img_mod1_gate
,
img_mod2_shift
,
img_mod2_scale
,
img_mod2_gate
,
):
out
=
weights
.
img_attn_proj
.
apply
(
img_attn
)
out
=
weights
.
img_attn_proj
.
apply
(
img_attn
)
out
=
out
*
img_mod1_gate
out
=
out
*
img_mod1_gate
img
=
img
+
out
img
=
img
+
out
...
@@ -130,7 +256,16 @@ class HunyuanTransformerInfer:
...
@@ -130,7 +256,16 @@ class HunyuanTransformerInfer:
img
=
img
+
out
img
=
img
+
out
return
img
return
img
def
infer_double_block_txt_post_atten
(
self
,
weights
,
txt
,
txt_attn
,
txt_mod1_gate
,
txt_mod2_shift
,
txt_mod2_scale
,
txt_mod2_gate
):
def
infer_double_block_txt_post_atten
(
self
,
weights
,
txt
,
txt_attn
,
txt_mod1_gate
,
txt_mod2_shift
,
txt_mod2_scale
,
txt_mod2_gate
,
):
out
=
weights
.
txt_attn_proj
.
apply
(
txt_attn
)
out
=
weights
.
txt_attn_proj
.
apply
(
txt_attn
)
out
=
out
*
txt_mod1_gate
out
=
out
*
txt_mod1_gate
txt
=
txt
+
out
txt
=
txt
+
out
...
@@ -144,7 +279,9 @@ class HunyuanTransformerInfer:
...
@@ -144,7 +279,9 @@ class HunyuanTransformerInfer:
txt
=
txt
+
out
txt
=
txt
+
out
return
txt
return
txt
def
infer_single_block
(
self
,
weights
,
x
,
vec
,
txt_seq_len
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
def
infer_single_block
(
self
,
weights
,
x
,
vec
,
txt_seq_len
,
cu_seqlens_qkv
,
max_seqlen_qkv
,
freqs_cis
):
out
=
torch
.
nn
.
functional
.
silu
(
vec
)
out
=
torch
.
nn
.
functional
.
silu
(
vec
)
out
=
weights
.
modulation
.
apply
(
out
)
out
=
weights
.
modulation
.
apply
(
out
)
mod_shift
,
mod_scale
,
mod_gate
=
out
.
chunk
(
3
,
dim
=-
1
)
mod_shift
,
mod_scale
,
mod_gate
=
out
.
chunk
(
3
,
dim
=-
1
)
...
@@ -154,7 +291,9 @@ class HunyuanTransformerInfer:
...
@@ -154,7 +291,9 @@ class HunyuanTransformerInfer:
x_mod
=
weights
.
linear1
.
apply
(
x_mod
)
x_mod
=
weights
.
linear1
.
apply
(
x_mod
)
qkv
,
mlp
=
torch
.
split
(
x_mod
,
[
3
*
self
.
hidden_size
,
self
.
mlp_hidden_dim
],
dim
=-
1
)
qkv
,
mlp
=
torch
.
split
(
x_mod
,
[
3
*
self
.
hidden_size
,
self
.
mlp_hidden_dim
],
dim
=-
1
)
q
,
k
,
v
=
rearrange
(
qkv
,
"L (K H D) -> K L H D"
,
K
=
3
,
H
=
self
.
heads_num
)
q
,
k
,
v
=
rearrange
(
qkv
,
"L (K H D) -> K L H D"
,
K
=
3
,
H
=
self
.
heads_num
)
...
...
lightx2v/text2v/models/networks/hunyuan/model.py
View file @
2b0139fe
...
@@ -17,9 +17,10 @@ class HunyuanModel:
...
@@ -17,9 +17,10 @@ class HunyuanModel:
post_weight_class
=
HunyuanPostWeights
post_weight_class
=
HunyuanPostWeights
transformer_weight_class
=
HunyuanTransformerWeights
transformer_weight_class
=
HunyuanTransformerWeights
def
__init__
(
self
,
model_path
,
config
):
def
__init__
(
self
,
model_path
,
config
,
device
):
self
.
model_path
=
model_path
self
.
model_path
=
model_path
self
.
config
=
config
self
.
config
=
config
self
.
device
=
device
self
.
_init_infer_class
()
self
.
_init_infer_class
()
self
.
_init_weights
()
self
.
_init_weights
()
self
.
_init_infer
()
self
.
_init_infer
()
...
@@ -47,7 +48,7 @@ class HunyuanModel:
...
@@ -47,7 +48,7 @@ class HunyuanModel:
def
_load_ckpt
(
self
):
def
_load_ckpt
(
self
):
ckpt_path
=
os
.
path
.
join
(
self
.
model_path
,
"hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
)
ckpt_path
=
os
.
path
.
join
(
self
.
model_path
,
"hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
)
weight_dict
=
torch
.
load
(
ckpt_path
,
map_location
=
"cuda"
,
weights_only
=
True
)[
"module"
]
weight_dict
=
torch
.
load
(
ckpt_path
,
map_location
=
self
.
device
,
weights_only
=
True
)[
"module"
]
return
weight_dict
return
weight_dict
def
_init_weights
(
self
):
def
_init_weights
(
self
):
...
@@ -82,6 +83,9 @@ class HunyuanModel:
...
@@ -82,6 +83,9 @@ class HunyuanModel:
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
infer
(
self
,
text_encoder_output
,
image_encoder_output
,
args
):
def
infer
(
self
,
text_encoder_output
,
image_encoder_output
,
args
):
if
self
.
config
[
"cpu_offload"
]:
self
.
pre_weight
.
to_cuda
()
self
.
post_weight
.
to_cuda
()
pre_infer_out
=
self
.
pre_infer
.
infer
(
pre_infer_out
=
self
.
pre_infer
.
infer
(
self
.
pre_weight
,
self
.
pre_weight
,
self
.
scheduler
.
latents
,
self
.
scheduler
.
latents
,
...
@@ -95,3 +99,6 @@ class HunyuanModel:
...
@@ -95,3 +99,6 @@ class HunyuanModel:
)
)
img
,
vec
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
*
pre_infer_out
)
img
,
vec
=
self
.
transformer_infer
.
infer
(
self
.
transformer_weights
,
*
pre_infer_out
)
self
.
scheduler
.
noise_pred
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
img
,
vec
,
self
.
scheduler
.
latents
.
shape
)
self
.
scheduler
.
noise_pred
=
self
.
post_infer
.
infer
(
self
.
post_weight
,
img
,
vec
,
self
.
scheduler
.
latents
.
shape
)
if
self
.
config
[
"cpu_offload"
]:
self
.
pre_weight
.
to_cpu
()
self
.
post_weight
.
to_cpu
()
\ No newline at end of file
lightx2v/text2v/models/networks/hunyuan/weights/transformer_weights.py
View file @
2b0139fe
...
@@ -80,20 +80,30 @@ class HunyuanTransformerDoubleBlock:
...
@@ -80,20 +80,30 @@ class HunyuanTransformerDoubleBlock:
]
]
for
mm_weight
in
self
.
weight_list
:
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
MMWeightTemplate
)
or
isinstance
(
mm_weight
,
RMSWeightTemplate
):
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)
)
:
mm_weight
.
set_config
(
self
.
config
[
"mm_config"
])
mm_weight
.
set_config
(
self
.
config
[
"mm_config"
])
mm_weight
.
load
(
weight_dict
)
mm_weight
.
load
(
weight_dict
)
def
to_cpu
(
self
):
def
to_cpu
(
self
):
for
mm_weight
in
self
.
weight_list
:
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
MMWeightTemplate
)
or
isinstance
(
mm_weight
,
RMSWeightTemplate
):
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)
)
:
mm_weight
.
to_cpu
()
mm_weight
.
to_cpu
()
def
to_cuda
(
self
):
def
to_cuda
(
self
):
for
mm_weight
in
self
.
weight_list
:
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
MMWeightTemplate
)
or
isinstance
(
mm_weight
,
RMSWeightTemplate
):
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)
)
:
mm_weight
.
to_cuda
()
mm_weight
.
to_cuda
()
def
to_cpu_sync
(
self
):
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)):
mm_weight
.
to_cpu
(
non_blocking
=
True
)
def
to_cuda_sync
(
self
):
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)):
mm_weight
.
to_cuda
(
non_blocking
=
True
)
class
HunyuanTransformerSingleBlock
:
class
HunyuanTransformerSingleBlock
:
def
__init__
(
self
,
block_index
,
config
):
def
__init__
(
self
,
block_index
,
config
):
...
@@ -122,16 +132,27 @@ class HunyuanTransformerSingleBlock:
...
@@ -122,16 +132,27 @@ class HunyuanTransformerSingleBlock:
]
]
for
mm_weight
in
self
.
weight_list
:
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
MMWeightTemplate
)
or
isinstance
(
mm_weight
,
RMSWeightTemplate
):
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)
)
:
mm_weight
.
set_config
(
self
.
config
[
"mm_config"
])
mm_weight
.
set_config
(
self
.
config
[
"mm_config"
])
mm_weight
.
load
(
weight_dict
)
mm_weight
.
load
(
weight_dict
)
def
to_cpu
(
self
):
def
to_cpu
(
self
):
for
mm_weight
in
self
.
weight_list
:
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
MMWeightTemplate
)
or
isinstance
(
mm_weight
,
RMSWeightTemplate
):
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)
)
:
mm_weight
.
to_cpu
()
mm_weight
.
to_cpu
()
def
to_cuda
(
self
):
def
to_cuda
(
self
):
for
mm_weight
in
self
.
weight_list
:
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
MMWeightTemplate
)
or
isinstance
(
mm_weight
,
RMSWeightTemplate
):
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)
)
:
mm_weight
.
to_cuda
()
mm_weight
.
to_cuda
()
def
to_cpu_sync
(
self
):
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)):
mm_weight
.
to_cpu
(
non_blocking
=
True
)
def
to_cuda_sync
(
self
):
for
mm_weight
in
self
.
weight_list
:
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
RMSWeightTemplate
)):
mm_weight
.
to_cuda
(
non_blocking
=
True
)
lightx2v/text2v/models/networks/wan/weights/transformer_weights.py
View file @
2b0139fe
...
@@ -87,9 +87,6 @@ class WanTransformerAttentionBlock:
...
@@ -87,9 +87,6 @@ class WanTransformerAttentionBlock:
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
LNWeightTemplate
,
RMSWeightTemplate
)):
if
isinstance
(
mm_weight
,
(
MMWeightTemplate
,
LNWeightTemplate
,
RMSWeightTemplate
)):
mm_weight
.
set_config
(
self
.
config
[
"mm_config"
])
mm_weight
.
set_config
(
self
.
config
[
"mm_config"
])
mm_weight
.
load
(
weight_dict
)
mm_weight
.
load
(
weight_dict
)
if
self
.
config
[
"cpu_offload"
]:
mm_weight
.
to_cpu
()
self
.
modulation
=
self
.
modulation
.
cpu
()
def
to_cpu
(
self
):
def
to_cpu
(
self
):
for
mm_weight
in
self
.
weight_list
:
for
mm_weight
in
self
.
weight_list
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment