xuwx1 / LightX2V — Commits

Commit ae96fdbf
Authored Apr 29, 2025 by helloyongyang

    Update weight modules. Simplify code.

parent 3996d421

Showing 19 changed files with 285 additions and 429 deletions (+285 / -429)
configs/wan_t2v.json                                               +2    -0
lightx2v/common/modules/__init__.py                                +0    -0
lightx2v/common/modules/weight_module.py                           +103  -0
lightx2v/common/ops/__init__.py                                    +1    -0
lightx2v/common/ops/norm/__init__.py                               +1    -0
lightx2v/common/ops/tensor/__init__.py                             +1    -0
lightx2v/common/ops/tensor/tensor.py                               +16   -0
lightx2v/models/networks/hunyuan/infer/transformer_infer.py        +6    -6
lightx2v/models/networks/hunyuan/model.py                          +3    -3
lightx2v/models/networks/hunyuan/weights/post_weights.py           +5    -25
lightx2v/models/networks/hunyuan/weights/pre_weights.py            +57   -89
lightx2v/models/networks/hunyuan/weights/transformer_weights.py    +29   -127
lightx2v/models/networks/wan/infer/post_infer.py                   +2    -2
lightx2v/models/networks/wan/infer/transformer_infer.py            +5    -5
lightx2v/models/networks/wan/model.py                              +4    -7
lightx2v/models/networks/wan/weights/post_weights.py               +6    -29
lightx2v/models/networks/wan/weights/pre_weights.py                +15   -48
lightx2v/models/networks/wan/weights/transformer_weights.py        +27   -88
lightx2v/utils/registry_factory.py                                 +2    -0
configs/wan_t2v.json

@@ -8,6 +8,8 @@
     "seed": 42,
     "sample_guide_scale": 6,
     "sample_shift": 8,
+    "enable_cfg": true,
+    "cpu_offload": false,
     "mm_config": {
         "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
         "weight_auto_quant": true
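The two new keys are plain runtime switches: the lightx2v/models/networks/wan/model.py hunk further down replaces the old do_classifier_free_guidance() check with config["enable_cfg"], and cpu_offload gates the offloaded inference path. A minimal sketch of reading and branching on them (the loader shown here is illustrative, not code from this commit):

import json

# Illustrative only: parse the runtime config and branch on the two new flags,
# mirroring the checks made in lightx2v/models/networks/wan/model.py.
with open("configs/wan_t2v.json") as f:
    config = json.load(f)

if config["cpu_offload"]:
    print("keep weights on CPU and stream them to the GPU block by block")
if config["enable_cfg"]:
    print("run an extra unconditional pass for classifier-free guidance")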
lightx2v/common/modules/__init__.py  0 → 100755 (new, empty file)
lightx2v/common/modules/weight_module.py  0 → 100644 (new file)

class WeightModule:
    def __init__(self):
        self._modules = {}
        self._parameters = {}

    def add_module(self, name, module):
        self._modules[name] = module
        setattr(self, name, module)

    def register_parameter(self, name, param):
        self._parameters[name] = param
        setattr(self, name, param)

    def load(self, weight_dict):
        for _, module in self._modules.items():
            if hasattr(module, "set_config"):
                module.set_config(self.config["mm_config"])
            if hasattr(module, "load"):
                module.load(weight_dict)
        for _, parameter in self._parameters.items():
            if hasattr(parameter, "set_config"):
                parameter.set_config(self.config["mm_config"])
            if hasattr(parameter, "load"):
                parameter.load(weight_dict)

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = {}
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param.detach().cpu().clone()
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".")
        return destination

    def named_parameters(self, prefix=""):
        for name, param in self._parameters.items():
            if param is not None:
                yield prefix + name, param
        for name, module in self._modules.items():
            if module is not None:
                yield from module.named_parameters(prefix + name + ".")

    def to_cpu(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "cpu"):
                self._parameters[name] = param.cpu()
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu"):
                module.to_cpu()

    def to_cuda(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "cuda"):
                self._parameters[name] = param.cuda()
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda"):
                module.to_cuda()

    def to_cpu_sync(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "to"):
                self._parameters[name] = param.to("cpu", non_blocking=True)
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu_sync"):
                module.to_cpu_sync()

    def to_cuda_sync(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "cuda"):
                self._parameters[name] = param.cuda(non_blocking=True)
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda_sync"):
                module.to_cuda_sync()


class WeightModuleList(WeightModule):
    def __init__(self, modules=None):
        super().__init__()
        self._list = []
        if modules is not None:
            for idx, module in enumerate(modules):
                self.append(module)

    def append(self, module):
        idx = len(self._list)
        self._list.append(module)
        self.add_module(str(idx), module)

    def __getitem__(self, idx):
        return self._list[idx]

    def __len__(self):
        return len(self._list)

    def __iter__(self):
        return iter(self._list)
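WeightModule is a torch.nn.Module-style container for the inference-only weight wrappers: add_module / register_parameter record children in _modules / _parameters and expose them as attributes, and load, state_dict, named_parameters and the to_cpu / to_cuda (plus non-blocking *_sync) helpers then recurse over everything that was registered. A minimal usage sketch (ExampleWeights and DummyLinear are illustrative names, not part of the repository):

import torch

# Illustrative child: anything exposing load/to_cpu/to_cuda can be registered.
class DummyLinear:
    def __init__(self, weight_name):
        self.weight_name = weight_name

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name]

    def to_cpu(self):
        self.weight = self.weight.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()


class ExampleWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # add_module stores the child in self._modules and exposes it as self.proj.
        self.add_module("proj", DummyLinear("proj.weight"))


weights = ExampleWeights({"mm_config": {}})
weights.load({"proj.weight": torch.randn(8, 8)})  # WeightModule.load forwards to every registered child
if torch.cuda.is_available():
    weights.to_cuda()                             # recursively migrates registered children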
lightx2v/common/ops/__init__.py

 from .mm import *
 from .norm import *
 from .conv import *
+from .tensor import *
lightx2v/common/ops/norm/__init__.py

from .rms_norm_weight import *
from .layer_norm_weight import *
lightx2v/common/ops/tensor/__init__.py  0 → 100755 (new file)

from .tensor import DefaultTensor
lightx2v/common/ops/tensor/tensor.py  0 → 100644 (new file)

from lightx2v.utils.registry_factory import TENSOR_REGISTER


@TENSOR_REGISTER("Default")
class DefaultTensor:
    def __init__(self, tensor_name):
        self.tensor_name = tensor_name

    def load(self, weight_dict):
        self.tensor = weight_dict[self.tensor_name]

    def to_cpu(self, non_blocking=False):
        self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        self.tensor = self.tensor.cuda(non_blocking=non_blocking)
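DefaultTensor is what lets plain tensors from the checkpoint (for example head.modulation or blocks.N.modulation) take part in the same load / to_cpu / to_cuda protocol as the matmul and norm weight wrappers; the wan weight classes below register it through register_parameter. A small usage sketch (the tensor name is real, the shape is made up for the example):

import torch
from lightx2v.utils.registry_factory import TENSOR_REGISTER

# Illustrative: wrap a raw checkpoint tensor so it follows the common protocol.
mod = TENSOR_REGISTER["Default"]("head.modulation")
mod.load({"head.modulation": torch.randn(1, 2, 1536)})  # shape chosen only for the sketch
print(mod.tensor.shape)
if torch.cuda.is_available():
    mod.to_cuda(non_blocking=True)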
lightx2v/models/networks/hunyuan/infer/transformer_infer.py

@@ -36,14 +36,14 @@ class HunyuanTransformerInfer:
         for double_block_idx in range(self.double_blocks_num):
             if double_block_idx == 0:
-                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
+                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks[0]
                 self.double_weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
                 img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
             if double_block_idx < self.double_blocks_num - 1:
-                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
+                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks)
             self.double_weights_stream_mgr.swap_weights()
         x = torch.cat((img, txt), 0)

@@ -55,12 +55,12 @@ class HunyuanTransformerInfer:
         for single_block_idx in range(self.single_blocks_num):
             if single_block_idx == 0:
-                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
+                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks[0]
                 self.single_weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
                 x = self.infer_single_block(self.single_weights_stream_mgr.active_weights[0], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
             if single_block_idx < self.single_blocks_num - 1:
-                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
+                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks)
             self.single_weights_stream_mgr.swap_weights()
         torch.cuda.empty_cache()

@@ -72,12 +72,12 @@ class HunyuanTransformerInfer:
         img_seq_len = img.shape[0]
         for i in range(self.double_blocks_num):
-            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
+            img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
         x = torch.cat((img, txt), 0)
         for i in range(self.single_blocks_num):
-            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
+            x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
         img = x[:img_seq_len, ...]
         return img, vec
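The renamed attributes feed the same offload loop as before: one slot holds the block currently being computed while the next block's weights are prefetched on a separate stream, then the two are swapped. The project's weight stream manager is not part of this commit; the sketch below is only a guess at that general double-buffering idea, not its actual implementation, and it requires a CUDA device.

import torch

# Hypothetical sketch of a two-slot weight streaming manager; NOT the manager
# used by LightX2V, just the pattern the loops above rely on.
class ToyWeightStreamMgr:
    def __init__(self):
        self.active_weights = [None, None]      # [current block, prefetched block]
        self.compute_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()

    def prefetch_weights(self, idx, blocks):
        # Copy the next block's weights host-to-device on the load stream.
        with torch.cuda.stream(self.load_stream):
            blocks[idx].to_cuda_sync()
            self.active_weights[1] = blocks[idx]

    def swap_weights(self):
        # Make sure the prefetch finished, then promote it to the active slot.
        torch.cuda.current_stream().wait_stream(self.load_stream)
        self.active_weights[0], self.active_weights[1] = self.active_weights[1], None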
lightx2v/models/networks/hunyuan/model.py

@@ -64,9 +64,9 @@ class HunyuanModel:
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
         # load weights
-        self.pre_weight.load_weights(weight_dict)
-        self.post_weight.load_weights(weight_dict)
-        self.transformer_weights.load_weights(weight_dict)
+        self.pre_weight.load(weight_dict)
+        self.post_weight.load(weight_dict)
+        self.transformer_weights.load(weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
lightx2v/models/networks/hunyuan/weights/post_weights.py

 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
+from lightx2v.common.modules.weight_module import WeightModule


-class HunyuanPostWeights:
+class HunyuanPostWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
+        self.add_module("final_layer_linear", MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias"))
+        self.add_module("final_layer_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"))

-    def load_weights(self, weight_dict):
-        self.final_layer_linear = MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias")
-        self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias")
-        self.weight_list = [
-            self.final_layer_linear,
-            self.final_layer_adaLN_modulation_1,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cuda()
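With HunyuanPostWeights now a WeightModule, the per-class bookkeeping (weight_list, isinstance filters, hand-written to_cpu/to_cuda) disappears: the weights are declared once in __init__ and the model drives them through the inherited methods, as the model.py hunk above shows. A sketch of the resulting call sequence (weight_dict stands for the checkpoint state dict and config for the parsed JSON config, both prepared elsewhere in model setup):

# Sketch only: how the simplified class is driven after this commit.
post_weight = HunyuanPostWeights(config)
post_weight.load(weight_dict)   # inherited WeightModule.load: set_config + load on each registered child
if config["cpu_offload"]:
    post_weight.to_cpu()        # inherited, walks every registered module
else:
    post_weight.to_cuda()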
lightx2v/models/networks/hunyuan/weights/pre_weights.py

 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
-from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
+from lightx2v.common.modules.weight_module import WeightModule


-class HunyuanPreWeights:
+class HunyuanPreWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
+        self.add_module("img_in_proj", CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2)))
+        self.add_module("txt_in_input_embedder", MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias"))
+        self.add_module("txt_in_t_embedder_mlp_0", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias"))
+        self.add_module("txt_in_t_embedder_mlp_2", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias"))
+        self.add_module("txt_in_c_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias"))
+        self.add_module("txt_in_c_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_norm1", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_self_attn_qkv", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_self_attn_proj", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_norm2", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_mlp_fc1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_mlp_fc2", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_norm1", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_self_attn_qkv", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_self_attn_proj", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_norm2", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_mlp_fc1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_mlp_fc2", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"))
+        self.add_module("time_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias"))
+        self.add_module("time_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias"))
+        self.add_module("vector_in_in_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias"))
+        self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
+        self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
+        self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))

-    def load_weights(self, weight_dict):
-        self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2))
-        self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias")
-        self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias")
-        self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias")
-        self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias")
-        self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias")
-        self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias")
-        self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias")
-        self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias")
-        self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias")
-        self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias")
-        self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias")
-        self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias")
-        self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias")
-        self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias")
-        self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias")
-        self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias")
-        self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias")
-        self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias")
-        self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias")
-        self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias")
-        self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias")
-        self.weight_list = [
-            self.img_in_proj,
-            self.txt_in_input_embedder,
-            self.txt_in_t_embedder_mlp_0,
-            self.txt_in_t_embedder_mlp_2,
-            self.txt_in_c_embedder_linear_1,
-            self.txt_in_c_embedder_linear_2,
-            self.txt_in_individual_token_refiner_blocks_0_norm1,
-            self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv,
-            self.txt_in_individual_token_refiner_blocks_0_self_attn_proj,
-            self.txt_in_individual_token_refiner_blocks_0_norm2,
-            self.txt_in_individual_token_refiner_blocks_0_mlp_fc1,
-            self.txt_in_individual_token_refiner_blocks_0_mlp_fc2,
-            self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1,
-            self.txt_in_individual_token_refiner_blocks_1_norm1,
-            self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv,
-            self.txt_in_individual_token_refiner_blocks_1_self_attn_proj,
-            self.txt_in_individual_token_refiner_blocks_1_norm2,
-            self.txt_in_individual_token_refiner_blocks_1_mlp_fc1,
-            self.txt_in_individual_token_refiner_blocks_1_mlp_fc2,
-            self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1,
-            self.time_in_mlp_0,
-            self.time_in_mlp_2,
-            self.vector_in_in_layer,
-            self.vector_in_out_layer,
-            self.guidance_in_mlp_0,
-            self.guidance_in_mlp_2,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
-                weight.to_cuda()
lightx2v/models/networks/hunyuan/weights/transformer_weights.py

-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.common.ops.norm.rms_norm_weight import RMS_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
+from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList


-class HunyuanTransformerWeights:
+class HunyuanTransformerWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
         self.init()

     def init(self):
         self.double_blocks_num = 20
         self.single_blocks_num = 40
+        self.add_module("double_blocks", WeightModuleList([HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]))
+        self.add_module("single_blocks", WeightModuleList([HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]))

-    def load_weights(self, weight_dict):
-        self.double_blocks_weights = [HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]
-        self.single_blocks_weights = [HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]
-        for double_block in self.double_blocks_weights:
-            double_block.load_weights(weight_dict)
-        for single_block in self.single_blocks_weights:
-            single_block.load_weights(weight_dict)
-
-    def to_cpu(self):
-        for double_block in self.double_blocks_weights:
-            double_block.to_cpu()
-        for single_block in self.single_blocks_weights:
-            single_block.to_cpu()
-
-    def to_cuda(self):
-        for double_block in self.double_blocks_weights:
-            double_block.to_cuda()
-        for single_block in self.single_blocks_weights:
-            single_block.to_cuda()


-class HunyuanTransformerDoubleBlock:
+class HunyuanTransformerDoubleBlock(WeightModule):
     def __init__(self, block_index, config):
+        super().__init__()
         self.block_index = block_index
         self.config = config
-        self.weight_list = []
-
-    def load_weights(self, weight_dict):
         if self.config["do_mm_calib"]:
             mm_type = "Calib"
         else:
             mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
+        self.add_module("img_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias"))
+        self.add_module("img_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias"))
+        self.add_module("img_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6))
+        self.add_module("img_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6))
+        self.add_module("img_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias"))
+        self.add_module("img_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias"))
+        self.add_module("img_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias"))
+        self.add_module("txt_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias"))
+        self.add_module("txt_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias"))
+        self.add_module("txt_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6))
+        self.add_module("txt_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6))
+        self.add_module("txt_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias"))
+        self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
+        self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))

-        self.img_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias")
-        self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias")
-        self.img_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6)
-        self.img_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6)
-        self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias")
-        self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias")
-        self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias")
-        self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias")
-        self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias")
-        self.txt_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6)
-        self.txt_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6)
-        self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias")
-        self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias")
-        self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias")
-        self.weight_list = [
-            self.img_mod,
-            self.img_attn_qkv,
-            self.img_attn_q_norm,
-            self.img_attn_k_norm,
-            self.img_attn_proj,
-            self.img_mlp_fc1,
-            self.img_mlp_fc2,
-            self.txt_mod,
-            self.txt_attn_qkv,
-            self.txt_attn_q_norm,
-            self.txt_attn_k_norm,
-            self.txt_attn_proj,
-            self.txt_mlp_fc1,
-            self.txt_mlp_fc2,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda()
-
-    def to_cpu_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu(non_blocking=True)
-
-    def to_cuda_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda(non_blocking=True)


-class HunyuanTransformerSingleBlock:
+class HunyuanTransformerSingleBlock(WeightModule):
     def __init__(self, block_index, config):
+        super().__init__()
         self.block_index = block_index
         self.config = config
-        self.weight_list = []
-
-    def load_weights(self, weight_dict):
         if self.config["do_mm_calib"]:
             mm_type = "Calib"
         else:
             mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
+        self.add_module("linear1", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias"))
+        self.add_module("linear2", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias"))
+        self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
+        self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
+        self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))

-        self.linear1 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias")
-        self.linear2 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias")
-        self.q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6)
-        self.k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6)
-        self.modulation = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias")
-        self.weight_list = [
-            self.linear1,
-            self.linear2,
-            self.q_norm,
-            self.k_norm,
-            self.modulation,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda()
-
-    def to_cpu_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu(non_blocking=True)
-
-    def to_cuda_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda(non_blocking=True)
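Because double_blocks and single_blocks are now WeightModuleList instances, they act both as registered sub-modules (so the inherited load / to_cuda recurse into every block) and as plain sequences, which is what lets the inference loop above index weights.double_blocks[i] directly. A small sketch of that dual role (config and weight_dict stand for objects prepared elsewhere; the printed count comes from the hard-coded double_blocks_num):

# Sketch: WeightModuleList gives both recursive loading and list-style access.
weights = HunyuanTransformerWeights(config)
weights.load(weight_dict)               # recurses into every double/single block

print(len(weights.double_blocks))       # 20
first_block = weights.double_blocks[0]  # __getitem__ from WeightModuleList
first_block.to_cuda()                   # a single block can be moved on its own (offload path)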
lightx2v/models/networks/wan/infer/post_infer.py

@@ -13,10 +13,10 @@ class WanPostInfer:
     def infer(self, weights, x, e, grid_sizes):
         if e.dim() == 2:
-            modulation = weights.head_modulation  # 1, 2, dim
+            modulation = weights.head_modulation.tensor  # 1, 2, dim
             e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
         elif e.dim() == 3:  # For Diffustion forcing
-            modulation = weights.head_modulation.unsqueeze(2)  # 1, 2, seq, dim
+            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
             e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
             e = [ei.squeeze(1) for ei in e]
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -42,7 +42,7 @@ class WanTransformerInfer:
     def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = weights.blocks_weights[0]
+                self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):

@@ -58,7 +58,7 @@ class WanTransformerInfer:
                 )
             if block_idx < self.blocks_num - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks_weights)
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
             self.weights_stream_mgr.swap_weights()
         return x

@@ -66,7 +66,7 @@ class WanTransformerInfer:
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
-                weights.blocks_weights[block_idx],
+                weights.blocks[block_idx],
                 grid_sizes,
                 embed,
                 x,

@@ -79,12 +79,12 @@ class WanTransformerInfer:
     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         if embed0.dim() == 3:
-            modulation = weights.modulation.unsqueeze(2)  # 1, 6, 1, dim
+            modulation = weights.modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
             embed0 = embed0.unsqueeze(0)  #
             embed0 = (modulation + embed0).chunk(6, dim=1)
             embed0 = [ei.squeeze(1) for ei in embed0]
         elif embed0.dim() == 2:
-            embed0 = (weights.modulation + embed0).chunk(6, dim=1)
+            embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
         norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
lightx2v/models/networks/wan/model.py

@@ -84,9 +84,9 @@ class WanModel:
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
         # load weights
-        self.pre_weight.load_weights(self.original_weight_dict)
-        self.post_weight.load_weights(self.original_weight_dict)
-        self.transformer_weights.load_weights(self.original_weight_dict)
+        self.pre_weight.load(self.original_weight_dict)
+        self.post_weight.load(self.original_weight_dict)
+        self.transformer_weights.load(self.original_weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)

@@ -109,9 +109,6 @@ class WanModel:
         self.post_weight.to_cuda()
         self.transformer_weights.to_cuda()

-    def do_classifier_free_guidance(self) -> bool:
-        return self.config.sample_guide_scale > 1
-
     @torch.no_grad()
     def infer(self, inputs):
         if self.config["cpu_offload"]:

@@ -128,7 +125,7 @@ class WanModel:
             self.scheduler.cnt = 0
         self.scheduler.noise_pred = noise_pred_cond

-        if self.do_classifier_free_guidance():
+        if self.config["enable_cfg"]:
             embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
lightx2v/models/networks/wan/weights/post_weights.py

-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, TENSOR_REGISTER
+from lightx2v.common.modules.weight_module import WeightModule


-class WanPostWeights:
+class WanPostWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
+        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
+        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

-    def load_weights(self, weight_dict):
-        self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
-        self.head_modulation = weight_dict["head.modulation"]
-        self.weight_list = [self.head]
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-                if self.config["cpu_offload"]:
-                    weight.to_cpu()
-                    self.head_modulation = self.head_modulation.cpu()
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cpu()
-        self.head_modulation = self.head_modulation.cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cuda()
-        self.head_modulation = self.head_modulation.cuda()
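register_parameter plus DefaultTensor is what makes the .tensor change in lightx2v/models/networks/wan/infer/post_infer.py work: head_modulation is no longer a raw tensor pulled out of weight_dict but a small wrapper that knows how to load and migrate itself, and the raw data now lives in its .tensor attribute. A sketch of the resulting access pattern (config and weight_dict stand for objects built during model setup; the shape and random addend are only for the example):

import torch

# Sketch only: the post weights after the refactor.
post_weight = WanPostWeights(config)
post_weight.load(weight_dict)                    # fills head and head_modulation

modulation = post_weight.head_modulation.tensor  # raw tensor, as post_infer now reads it
e = (modulation + torch.randn_like(modulation)).chunk(2, dim=1)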
lightx2v/models/networks/wan/weights/pre_weights.py

 import torch
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
-from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
+from lightx2v.common.modules.weight_module import WeightModule


-class WanPreWeights:
+class WanPreWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.in_dim = config["in_dim"]
         self.dim = config["dim"]
         self.patch_size = (1, 2, 2)
         self.config = config
+        self.add_module("patch_embedding", CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size))
+        self.add_module("text_embedding_0", MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"))
+        self.add_module("text_embedding_2", MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"))
+        self.add_module("time_embedding_0", MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"))
+        self.add_module("time_embedding_2", MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"))
+        self.add_module("time_projection_1", MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"))
+        if config.task == "i2v":
+            self.add_module("proj_0", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5))
+            self.add_module("proj_1", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"))
+            self.add_module("proj_3", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"))
+            self.add_module("proj_4", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5))

-    def load_weights(self, weight_dict):
-        self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size)
-        self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias")
-        self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias")
-        self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias")
-        self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias")
-        self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias")
-        self.weight_list = [
-            self.patch_embedding,
-            self.text_embedding_0,
-            self.text_embedding_2,
-            self.time_embedding_0,
-            self.time_embedding_2,
-            self.time_projection_1,
-        ]
-        if "img_emb.proj.0.weight" in weight_dict.keys():
-            self.proj_0 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5)
-            self.proj_1 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias")
-            self.proj_3 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias")
-            self.proj_4 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5)
-            self.weight_list.append(self.proj_0)
-            self.weight_list.append(self.proj_1)
-            self.weight_list.append(self.proj_3)
-            self.weight_list.append(self.proj_4)
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-                if self.config["cpu_offload"]:
-                    weight.to_cpu()
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
-                weight.to_cuda()
lightx2v/models/networks/wan/weights/transformer_weights.py

-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
-from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER
+from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList


-class WanTransformerWeights:
+class WanTransformerWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.blocks_num = config["num_layers"]
         self.task = config["task"]
         self.config = config

@@ -13,99 +12,39 @@ class WanTransformerWeights:
             self.mm_type = "Calib"
         else:
             self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
+        self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
+        self.add_module("blocks", self.blocks)

-    def load_weights(self, weight_dict):
-        self.blocks_weights = [WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]
-        for block in self.blocks_weights:
-            block.load_weights(weight_dict)
-
-    def to_cpu(self):
-        for block in self.blocks_weights:
-            block.to_cpu()
-
-    def to_cuda(self):
-        for block in self.blocks_weights:
-            block.to_cuda()


-class WanTransformerAttentionBlock:
+class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
+        super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
+        self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
+        self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
+        self.add_module("self_attn_v", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias"))
+        self.add_module("self_attn_o", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias"))
+        self.add_module("self_attn_norm_q", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight"))
+        self.add_module("self_attn_norm_k", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight"))
+        self.add_module("norm3", LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6))
+        self.add_module("cross_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias"))
+        self.add_module("cross_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias"))
+        self.add_module("cross_attn_v", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias"))
+        self.add_module("cross_attn_o", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias"))
+        self.add_module("cross_attn_norm_q", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight"))
+        self.add_module("cross_attn_norm_k", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight"))
+        self.add_module("ffn_0", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias"))
+        self.add_module("ffn_2", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias"))
         if self.task == "i2v":
+            self.add_module("cross_attn_k_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias"))
+            self.add_module("cross_attn_v_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias"))
+            self.add_module("cross_attn_norm_k_img", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight"))
+        self.register_parameter("modulation", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"))

-    def load_weights(self, weight_dict):
-        self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias")
-        self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias")
-        self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias")
-        self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias")
-        self.self_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight")
-        self.self_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight")
-        self.norm3 = LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6)
-        self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias")
-        self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias")
-        self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias")
-        self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias")
-        self.cross_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight")
-        self.cross_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight")
-        self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias")
-        self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias")
-        self.modulation = weight_dict[f"blocks.{self.block_index}.modulation"]
-        self.weight_list = [
-            self.self_attn_q,
-            self.self_attn_k,
-            self.self_attn_v,
-            self.self_attn_o,
-            self.self_attn_norm_q,
-            self.self_attn_norm_k,
-            self.norm3,
-            self.cross_attn_q,
-            self.cross_attn_k,
-            self.cross_attn_v,
-            self.cross_attn_o,
-            self.cross_attn_norm_q,
-            self.cross_attn_norm_k,
-            self.ffn_0,
-            self.ffn_2,
-        ]
-        if self.task == "i2v":
-            self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias")
-            self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias")
-            self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight")
-            self.weight_list.append(self.cross_attn_k_img)
-            self.weight_list.append(self.cross_attn_v_img)
-            self.weight_list.append(self.cross_attn_norm_k_img)
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu()
-        self.modulation = self.modulation.cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda()
-        self.modulation = self.modulation.cuda()
-
-    def to_cpu_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu(non_blocking=True)
-        self.modulation = self.modulation.to("cpu", non_blocking=True)
-
-    def to_cuda_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda(non_blocking=True)
-        self.modulation = self.modulation.cuda(non_blocking=True)
lightx2v/utils/registry_factory.py

@@ -50,4 +50,6 @@ LN_WEIGHT_REGISTER = Register()
 CONV3D_WEIGHT_REGISTER = Register()
 CONV2D_WEIGHT_REGISTER = Register()
+TENSOR_REGISTER = Register()
 RUNNER_REGISTER = Register()
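The hunk only adds the new TENSOR_REGISTER instance; the Register class itself is defined earlier in this file and is not shown in the commit. For orientation, registries used this way (decorator registration plus key lookup) typically look roughly like the sketch below; this is an assumption about the pattern, not the file's actual implementation.

# Assumed shape of such a registry; NOT the actual Register from registry_factory.py.
class Register(dict):
    def __call__(self, name):
        def decorator(cls):
            self[name] = cls   # @TENSOR_REGISTER("Default") stores the class under "Default"
            return cls
        return decorator


TENSOR_REGISTER = Register()


@TENSOR_REGISTER("Default")
class DefaultTensor:
    def __init__(self, tensor_name):
        self.tensor_name = tensor_name


tensor_cls = TENSOR_REGISTER["Default"]  # lookup style used throughout the weight classes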