xuwx1 / LightX2V · Commits

Commit ae96fdbf
Authored Apr 29, 2025 by helloyongyang
Parent: 3996d421

    Update weight modules. Simplify code.
Showing 19 changed files with 285 additions and 429 deletions.
configs/wan_t2v.json                                              +2    -0
lightx2v/common/modules/__init__.py                               +0    -0
lightx2v/common/modules/weight_module.py                          +103  -0
lightx2v/common/ops/__init__.py                                   +1    -0
lightx2v/common/ops/norm/__init__.py                              +1    -0
lightx2v/common/ops/tensor/__init__.py                            +1    -0
lightx2v/common/ops/tensor/tensor.py                              +16   -0
lightx2v/models/networks/hunyuan/infer/transformer_infer.py       +6    -6
lightx2v/models/networks/hunyuan/model.py                         +3    -3
lightx2v/models/networks/hunyuan/weights/post_weights.py          +5    -25
lightx2v/models/networks/hunyuan/weights/pre_weights.py           +57   -89
lightx2v/models/networks/hunyuan/weights/transformer_weights.py   +29   -127
lightx2v/models/networks/wan/infer/post_infer.py                  +2    -2
lightx2v/models/networks/wan/infer/transformer_infer.py           +5    -5
lightx2v/models/networks/wan/model.py                             +4    -7
lightx2v/models/networks/wan/weights/post_weights.py              +6    -29
lightx2v/models/networks/wan/weights/pre_weights.py               +15   -48
lightx2v/models/networks/wan/weights/transformer_weights.py       +27   -88
lightx2v/utils/registry_factory.py                                +2    -0
configs/wan_t2v.json
@@ -8,6 +8,8 @@
     "seed": 42,
     "sample_guide_scale": 6,
     "sample_shift": 8,
+    "enable_cfg": true,
+    "cpu_offload": false,
     "mm_config": {
         "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
         "weight_auto_quant": true
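Both new keys are consumed in lightx2v/models/networks/wan/model.py below: enable_cfg replaces the old sample_guide_scale > 1 check that gated the unconditional pass, and cpu_offload selects the offloaded inference path.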
lightx2v/common/modules/__init__.py (new file, empty)
lightx2v/common/modules/weight_module.py (new file)

class WeightModule:
    def __init__(self):
        self._modules = {}
        self._parameters = {}

    def add_module(self, name, module):
        self._modules[name] = module
        setattr(self, name, module)

    def register_parameter(self, name, param):
        self._parameters[name] = param
        setattr(self, name, param)

    def load(self, weight_dict):
        for _, module in self._modules.items():
            if hasattr(module, "set_config"):
                module.set_config(self.config["mm_config"])
            if hasattr(module, "load"):
                module.load(weight_dict)
        for _, parameter in self._parameters.items():
            if hasattr(parameter, "set_config"):
                parameter.set_config(self.config["mm_config"])
            if hasattr(parameter, "load"):
                parameter.load(weight_dict)

    def state_dict(self, destination=None, prefix=""):
        if destination is None:
            destination = {}
        for name, param in self._parameters.items():
            if param is not None:
                destination[prefix + name] = param.detach().cpu().clone()
        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".")
        return destination

    def named_parameters(self, prefix=""):
        for name, param in self._parameters.items():
            if param is not None:
                yield prefix + name, param
        for name, module in self._modules.items():
            if module is not None:
                yield from module.named_parameters(prefix + name + ".")

    def to_cpu(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "cpu"):
                self._parameters[name] = param.cpu()
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu"):
                module.to_cpu()

    def to_cuda(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "cuda"):
                self._parameters[name] = param.cuda()
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda"):
                module.to_cuda()

    def to_cpu_sync(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "to"):
                self._parameters[name] = param.to("cpu", non_blocking=True)
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu_sync"):
                module.to_cpu_sync()

    def to_cuda_sync(self):
        for name, param in self._parameters.items():
            if param is not None and hasattr(param, "cuda"):
                self._parameters[name] = param.cuda(non_blocking=True)
                setattr(self, name, self._parameters[name])
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda_sync"):
                module.to_cuda_sync()


class WeightModuleList(WeightModule):
    def __init__(self, modules=None):
        super().__init__()
        self._list = []
        if modules is not None:
            for idx, module in enumerate(modules):
                self.append(module)

    def append(self, module):
        idx = len(self._list)
        self._list.append(module)
        self.add_module(str(idx), module)

    def __getitem__(self, idx):
        return self._list[idx]

    def __len__(self):
        return len(self._list)

    def __iter__(self):
        return iter(self._list)
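WeightModule.load assumes the subclass has already stored self.config with an "mm_config" entry, and that leaf objects expose set_config/load; everything else recurses through _modules and _parameters. A minimal sketch of how a leaf plugs into the container — ToyWeight and the weight dict below are hypothetical stand-ins for the registry-produced wrappers, not code from this commit:

import torch

from lightx2v.common.modules.weight_module import WeightModule


class ToyWeight:
    # hypothetical leaf; real leaves come from MM_WEIGHT_REGISTER and friends
    def __init__(self, name):
        self.name = name

    def load(self, weight_dict):
        self.tensor = weight_dict[self.name]

    def to_cpu(self):
        self.tensor = self.tensor.cpu()


class ToyBlock(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.add_module("proj", ToyWeight("blocks.0.proj.weight"))


block = ToyBlock(config={"mm_config": None})
block.load({"blocks.0.proj.weight": torch.zeros(4, 4)})  # recurses into every registered module
block.to_cpu()  # offload recurses too; subclasses no longer hand-write these loops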
lightx2v/common/ops/__init__.py
 from .mm import *
 from .norm import *
 from .conv import *
+from .tensor import *
lightx2v/common/ops/norm/__init__.py
 from .rms_norm_weight import *
+from .layer_norm_weight import *
lightx2v/common/ops/tensor/__init__.py (new file)

from .tensor import DefaultTensor
lightx2v/common/ops/tensor/tensor.py (new file)

from lightx2v.utils.registry_factory import TENSOR_REGISTER


@TENSOR_REGISTER("Default")
class DefaultTensor:
    def __init__(self, tensor_name):
        self.tensor_name = tensor_name

    def load(self, weight_dict):
        self.tensor = weight_dict[self.tensor_name]

    def to_cpu(self, non_blocking=False):
        self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=False):
        self.tensor = self.tensor.cuda(non_blocking=non_blocking)
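DefaultTensor gives a plain tensor the same load/to_cpu/to_cuda protocol as the other weight wrappers, which is what lets modulation tables be attached via register_parameter later in this commit. A quick sketch of the lookup-and-load flow; the shape and dict contents are illustrative only:

import torch

from lightx2v.common.ops.tensor import DefaultTensor  # importing triggers @TENSOR_REGISTER registration
from lightx2v.utils.registry_factory import TENSOR_REGISTER

mod = TENSOR_REGISTER["Default"]("head.modulation")   # resolves to the DefaultTensor class
mod.load({"head.modulation": torch.zeros(1, 2, 16)})  # shape per the "1, 2, dim" comment in post_infer
print(mod.tensor.shape)  # consumers read the raw tensor through .tensor, as the wan infer code does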
lightx2v/models/networks/hunyuan/infer/transformer_infer.py
@@ -36,14 +36,14 @@ class HunyuanTransformerInfer:
         for double_block_idx in range(self.double_blocks_num):
             if double_block_idx == 0:
-                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks_weights[0]
+                self.double_weights_stream_mgr.active_weights[0] = weights.double_blocks[0]
                 self.double_weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.double_weights_stream_mgr.compute_stream):
                 img, txt = self.infer_double_block(self.double_weights_stream_mgr.active_weights[0], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
             if double_block_idx < self.double_blocks_num - 1:
-                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks_weights)
+                self.double_weights_stream_mgr.prefetch_weights(double_block_idx + 1, weights.double_blocks)
             self.double_weights_stream_mgr.swap_weights()
         x = torch.cat((img, txt), 0)
@@ -55,12 +55,12 @@ class HunyuanTransformerInfer:
         for single_block_idx in range(self.single_blocks_num):
             if single_block_idx == 0:
-                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks_weights[0]
+                self.single_weights_stream_mgr.active_weights[0] = weights.single_blocks[0]
                 self.single_weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.single_weights_stream_mgr.compute_stream):
                 x = self.infer_single_block(self.single_weights_stream_mgr.active_weights[0], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
             if single_block_idx < self.single_blocks_num - 1:
-                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks_weights)
+                self.single_weights_stream_mgr.prefetch_weights(single_block_idx + 1, weights.single_blocks)
             self.single_weights_stream_mgr.swap_weights()
         torch.cuda.empty_cache()
@@ -72,12 +72,12 @@ class HunyuanTransformerInfer:
         img_seq_len = img.shape[0]
         for i in range(self.double_blocks_num):
-            img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
+            img, txt = self.infer_double_block(weights.double_blocks[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
         x = torch.cat((img, txt), 0)
         for i in range(self.single_blocks_num):
-            x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
+            x = self.infer_single_block(weights.single_blocks[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
         img = x[:img_seq_len, ...]
         return img, vec
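The offload path double-buffers block weights: while block i computes on compute_stream, block i+1's weights are copied host-to-device on a second stream, and swap_weights rotates the buffers. The manager's internals are not part of this diff; the following is a minimal sketch of the pattern only, with every name beyond the ones used above assumed:

import torch


class ToyWeightsStreamMgr:
    # illustrative only; the real manager lives elsewhere in lightx2v
    def __init__(self):
        self.active_weights = [None, None]      # [0] = computing block, [1] = prefetched block
        self.compute_stream = torch.cuda.Stream()
        self.load_stream = torch.cuda.Stream()  # assumed second stream for H2D copies

    def prefetch_weights(self, idx, blocks):
        with torch.cuda.stream(self.load_stream):
            blocks[idx].to_cuda_sync()          # non-blocking copies queued on load_stream
            self.active_weights[1] = blocks[idx]

    def swap_weights(self):
        # the next block must not compute before its copy has landed
        self.compute_stream.wait_stream(self.load_stream)
        self.active_weights[0], self.active_weights[1] = self.active_weights[1], self.active_weights[0]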
lightx2v/models/networks/hunyuan/model.py
@@ -64,9 +64,9 @@ class HunyuanModel:
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
         # load weights
-        self.pre_weight.load_weights(weight_dict)
-        self.post_weight.load_weights(weight_dict)
-        self.transformer_weights.load_weights(weight_dict)
+        self.pre_weight.load(weight_dict)
+        self.post_weight.load(weight_dict)
+        self.transformer_weights.load(weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
lightx2v/models/networks/hunyuan/weights/post_weights.py
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
+from lightx2v.common.modules.weight_module import WeightModule


-class HunyuanPostWeights:
+class HunyuanPostWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
-
-    def load_weights(self, weight_dict):
-        self.final_layer_linear = MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias")
-        self.final_layer_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias")
-        self.weight_list = [
-            self.final_layer_linear,
-            self.final_layer_adaLN_modulation_1,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cuda()
+        self.add_module("final_layer_linear", MM_WEIGHT_REGISTER["Default-Force-FP32"]("final_layer.linear.weight", "final_layer.linear.bias"))
+        self.add_module("final_layer_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("final_layer.adaLN_modulation.1.weight", "final_layer.adaLN_modulation.1.bias"))
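The same mechanical pattern repeats across the remaining weight files: each self.x = REGISTER[...](...) assignment inside a hand-written load_weights moves into __init__ as an add_module(...) call, raw tensors move into register_parameter(...) wrappers, and the per-class weight_list / to_cpu / to_cuda / *_sync loops are deleted in favor of the recursive implementations inherited from WeightModule.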
lightx2v/models/networks/hunyuan/weights/pre_weights.py
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
-from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
+from lightx2v.common.modules.weight_module import WeightModule


-class HunyuanPreWeights:
+class HunyuanPreWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
-
-    def load_weights(self, weight_dict):
-        self.img_in_proj = CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2))
-        self.txt_in_input_embedder = MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias")
-        self.txt_in_t_embedder_mlp_0 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias")
-        self.txt_in_t_embedder_mlp_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias")
-        self.txt_in_c_embedder_linear_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias")
-        self.txt_in_c_embedder_linear_2 = MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias")
-        self.txt_in_individual_token_refiner_blocks_0_norm1 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias")
-        self.txt_in_individual_token_refiner_blocks_0_self_attn_proj = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias")
-        self.txt_in_individual_token_refiner_blocks_0_norm2 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_0_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias")
-        self.txt_in_individual_token_refiner_blocks_0_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias")
-        self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias")
-        self.txt_in_individual_token_refiner_blocks_1_norm1 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias")
-        self.txt_in_individual_token_refiner_blocks_1_self_attn_proj = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias")
-        self.txt_in_individual_token_refiner_blocks_1_norm2 = LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6)
-        self.txt_in_individual_token_refiner_blocks_1_mlp_fc1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias")
-        self.txt_in_individual_token_refiner_blocks_1_mlp_fc2 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias")
-        self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1 = MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias")
-        self.time_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias")
-        self.time_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias")
-        self.vector_in_in_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias")
-        self.vector_in_out_layer = MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias")
-        self.guidance_in_mlp_0 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias")
-        self.guidance_in_mlp_2 = MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias")
-        self.weight_list = [
-            self.img_in_proj,
-            self.txt_in_input_embedder,
-            self.txt_in_t_embedder_mlp_0,
-            self.txt_in_t_embedder_mlp_2,
-            self.txt_in_c_embedder_linear_1,
-            self.txt_in_c_embedder_linear_2,
-            self.txt_in_individual_token_refiner_blocks_0_norm1,
-            self.txt_in_individual_token_refiner_blocks_0_self_attn_qkv,
-            self.txt_in_individual_token_refiner_blocks_0_self_attn_proj,
-            self.txt_in_individual_token_refiner_blocks_0_norm2,
-            self.txt_in_individual_token_refiner_blocks_0_mlp_fc1,
-            self.txt_in_individual_token_refiner_blocks_0_mlp_fc2,
-            self.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1,
-            self.txt_in_individual_token_refiner_blocks_1_norm1,
-            self.txt_in_individual_token_refiner_blocks_1_self_attn_qkv,
-            self.txt_in_individual_token_refiner_blocks_1_self_attn_proj,
-            self.txt_in_individual_token_refiner_blocks_1_norm2,
-            self.txt_in_individual_token_refiner_blocks_1_mlp_fc1,
-            self.txt_in_individual_token_refiner_blocks_1_mlp_fc2,
-            self.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1,
-            self.time_in_mlp_0,
-            self.time_in_mlp_2,
-            self.vector_in_in_layer,
-            self.vector_in_out_layer,
-            self.guidance_in_mlp_0,
-            self.guidance_in_mlp_2,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate) or isinstance(weight, LNWeightTemplate) or isinstance(weight, Conv3dWeightTemplate):
-                weight.to_cuda()
+        self.add_module("img_in_proj", CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2)))
+        self.add_module("txt_in_input_embedder", MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias"))
+        self.add_module("txt_in_t_embedder_mlp_0", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias"))
+        self.add_module("txt_in_t_embedder_mlp_2", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias"))
+        self.add_module("txt_in_c_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias"))
+        self.add_module("txt_in_c_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_norm1", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_self_attn_qkv", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_self_attn_proj", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_norm2", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_mlp_fc1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_mlp_fc2", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_norm1", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_self_attn_qkv", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_self_attn_proj", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_norm2", LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_mlp_fc1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_mlp_fc2", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"))
+        self.add_module("txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1", MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"))
+        self.add_module("time_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias"))
+        self.add_module("time_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias"))
+        self.add_module("vector_in_in_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias"))
+        self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
+        self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
+        self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))
lightx2v/models/networks/hunyuan/weights/transformer_weights.py
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.common.ops.norm.rms_norm_weight import RMS_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
+from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList


-class HunyuanTransformerWeights:
+class HunyuanTransformerWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
-        self.init()
-
-    def init(self):
         self.double_blocks_num = 20
         self.single_blocks_num = 40
-
-    def load_weights(self, weight_dict):
-        self.double_blocks_weights = [HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]
-        self.single_blocks_weights = [HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]
-        for double_block in self.double_blocks_weights:
-            double_block.load_weights(weight_dict)
-        for single_block in self.single_blocks_weights:
-            single_block.load_weights(weight_dict)
-
-    def to_cpu(self):
-        for double_block in self.double_blocks_weights:
-            double_block.to_cpu()
-        for single_block in self.single_blocks_weights:
-            single_block.to_cpu()
-
-    def to_cuda(self):
-        for double_block in self.double_blocks_weights:
-            double_block.to_cuda()
-        for single_block in self.single_blocks_weights:
-            single_block.to_cuda()
+        self.add_module("double_blocks", WeightModuleList([HunyuanTransformerDoubleBlock(i, self.config) for i in range(self.double_blocks_num)]))
+        self.add_module("single_blocks", WeightModuleList([HunyuanTransformerSingleBlock(i, self.config) for i in range(self.single_blocks_num)]))


-class HunyuanTransformerDoubleBlock:
+class HunyuanTransformerDoubleBlock(WeightModule):
     def __init__(self, block_index, config):
+        super().__init__()
         self.block_index = block_index
         self.config = config
-        self.weight_list = []
-
-    def load_weights(self, weight_dict):
         if self.config["do_mm_calib"]:
             mm_type = "Calib"
         else:
             mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
-        self.img_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias")
-        self.img_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias")
-        self.img_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6)
-        self.img_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6)
-        self.img_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias")
-        self.img_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias")
-        self.img_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias")
-        self.txt_mod = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias")
-        self.txt_attn_qkv = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias")
-        self.txt_attn_q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6)
-        self.txt_attn_k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6)
-        self.txt_attn_proj = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias")
-        self.txt_mlp_fc1 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias")
-        self.txt_mlp_fc2 = MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias")
-        self.weight_list = [
-            self.img_mod,
-            self.img_attn_qkv,
-            self.img_attn_q_norm,
-            self.img_attn_k_norm,
-            self.img_attn_proj,
-            self.img_mlp_fc1,
-            self.img_mlp_fc2,
-            self.txt_mod,
-            self.txt_attn_qkv,
-            self.txt_attn_q_norm,
-            self.txt_attn_k_norm,
-            self.txt_attn_proj,
-            self.txt_mlp_fc1,
-            self.txt_mlp_fc2,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda()
-
-    def to_cpu_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu(non_blocking=True)
-
-    def to_cuda_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda(non_blocking=True)
+        self.add_module("img_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mod.linear.weight", f"double_blocks.{self.block_index}.img_mod.linear.bias"))
+        self.add_module("img_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_qkv.weight", f"double_blocks.{self.block_index}.img_attn_qkv.bias"))
+        self.add_module("img_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_q_norm.weight", eps=1e-6))
+        self.add_module("img_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.img_attn_k_norm.weight", eps=1e-6))
+        self.add_module("img_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_attn_proj.weight", f"double_blocks.{self.block_index}.img_attn_proj.bias"))
+        self.add_module("img_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc1.weight", f"double_blocks.{self.block_index}.img_mlp.fc1.bias"))
+        self.add_module("img_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.img_mlp.fc2.weight", f"double_blocks.{self.block_index}.img_mlp.fc2.bias"))
+        self.add_module("txt_mod", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mod.linear.weight", f"double_blocks.{self.block_index}.txt_mod.linear.bias"))
+        self.add_module("txt_attn_qkv", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_qkv.weight", f"double_blocks.{self.block_index}.txt_attn_qkv.bias"))
+        self.add_module("txt_attn_q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_q_norm.weight", eps=1e-6))
+        self.add_module("txt_attn_k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"double_blocks.{self.block_index}.txt_attn_k_norm.weight", eps=1e-6))
+        self.add_module("txt_attn_proj", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_attn_proj.weight", f"double_blocks.{self.block_index}.txt_attn_proj.bias"))
+        self.add_module("txt_mlp_fc1", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc1.weight", f"double_blocks.{self.block_index}.txt_mlp.fc1.bias"))
+        self.add_module("txt_mlp_fc2", MM_WEIGHT_REGISTER[mm_type](f"double_blocks.{self.block_index}.txt_mlp.fc2.weight", f"double_blocks.{self.block_index}.txt_mlp.fc2.bias"))


-class HunyuanTransformerSingleBlock:
+class HunyuanTransformerSingleBlock(WeightModule):
     def __init__(self, block_index, config):
+        super().__init__()
         self.block_index = block_index
         self.config = config
-        self.weight_list = []
-
-    def load_weights(self, weight_dict):
         if self.config["do_mm_calib"]:
             mm_type = "Calib"
         else:
             mm_type = self.config["mm_config"].get("mm_type", "Default") if self.config["mm_config"] else "Default"
-        self.linear1 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias")
-        self.linear2 = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias")
-        self.q_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6)
-        self.k_norm = RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6)
-        self.modulation = MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias")
-        self.weight_list = [
-            self.linear1,
-            self.linear2,
-            self.q_norm,
-            self.k_norm,
-            self.modulation,
-        ]
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda()
-
-    def to_cpu_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu(non_blocking=True)
-
-    def to_cuda_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda(non_blocking=True)
+        self.add_module("linear1", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear1.weight", f"single_blocks.{self.block_index}.linear1.bias"))
+        self.add_module("linear2", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.linear2.weight", f"single_blocks.{self.block_index}.linear2.bias"))
+        self.add_module("q_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.q_norm.weight", eps=1e-6))
+        self.add_module("k_norm", RMS_WEIGHT_REGISTER["sgl-kernel"](f"single_blocks.{self.block_index}.k_norm.weight", eps=1e-6))
+        self.add_module("modulation", MM_WEIGHT_REGISTER[mm_type](f"single_blocks.{self.block_index}.modulation.linear.weight", f"single_blocks.{self.block_index}.modulation.linear.bias"))
lightx2v/models/networks/wan/infer/post_infer.py
@@ -13,10 +13,10 @@ class WanPostInfer:
     def infer(self, weights, x, e, grid_sizes):
         if e.dim() == 2:
-            modulation = weights.head_modulation  # 1, 2, dim
+            modulation = weights.head_modulation.tensor  # 1, 2, dim
             e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
         elif e.dim() == 3:  # For Diffustion forcing
-            modulation = weights.head_modulation.unsqueeze(2)  # 1, 2, seq, dim
+            modulation = weights.head_modulation.tensor.unsqueeze(2)  # 1, 2, seq, dim
             e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
             e = [ei.squeeze(1) for ei in e]
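head_modulation was previously a raw tensor pulled straight out of weight_dict; after this commit it is a DefaultTensor registered with register_parameter (see wan/weights/post_weights.py below), so call sites unwrap it via .tensor. The same unwrapping appears for the per-block modulation table in the next file, keeping offload handling uniform across all weights.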
lightx2v/models/networks/wan/infer/transformer_infer.py
@@ -42,7 +42,7 @@ class WanTransformerInfer:
     def _infer_with_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = weights.blocks_weights[0]
+                self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
@@ -58,7 +58,7 @@ class WanTransformerInfer:
             )
             if block_idx < self.blocks_num - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks_weights)
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
             self.weights_stream_mgr.swap_weights()
         return x
@@ -66,7 +66,7 @@ class WanTransformerInfer:
     def _infer_without_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(self.blocks_num):
             x = self.infer_block(
-                weights.blocks_weights[block_idx],
+                weights.blocks[block_idx],
                 grid_sizes,
                 embed,
                 x,
@@ -79,12 +79,12 @@ class WanTransformerInfer:
     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         if embed0.dim() == 3:
-            modulation = weights.modulation.unsqueeze(2)  # 1, 6, 1, dim
+            modulation = weights.modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
             embed0 = embed0.unsqueeze(0)  #
             embed0 = (modulation + embed0).chunk(6, dim=1)
             embed0 = [ei.squeeze(1) for ei in embed0]
         elif embed0.dim() == 2:
-            embed0 = (weights.modulation + embed0).chunk(6, dim=1)
+            embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
         norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
lightx2v/models/networks/wan/model.py
@@ -84,9 +84,9 @@ class WanModel:
         self.post_weight = self.post_weight_class(self.config)
         self.transformer_weights = self.transformer_weight_class(self.config)
         # load weights
-        self.pre_weight.load_weights(self.original_weight_dict)
-        self.post_weight.load_weights(self.original_weight_dict)
-        self.transformer_weights.load_weights(self.original_weight_dict)
+        self.pre_weight.load(self.original_weight_dict)
+        self.post_weight.load(self.original_weight_dict)
+        self.transformer_weights.load(self.original_weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
@@ -109,9 +109,6 @@ class WanModel:
         self.post_weight.to_cuda()
         self.transformer_weights.to_cuda()

-    def do_classifier_free_guidance(self) -> bool:
-        return self.config.sample_guide_scale > 1
-
     @torch.no_grad()
     def infer(self, inputs):
         if self.config["cpu_offload"]:
@@ -128,7 +125,7 @@ class WanModel:
         self.scheduler.cnt = 0
         self.scheduler.noise_pred = noise_pred_cond
-        if self.do_classifier_free_guidance():
+        if self.config["enable_cfg"]:
             embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
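The scale-based check (sample_guide_scale > 1) is replaced by the explicit enable_cfg switch added to the config above. The hunk only shows the unconditional pass; how the two predictions combine is outside this diff, but classifier-free guidance conventionally blends them as in this sketch (the standard formula, not necessarily the repo's exact line):

noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_cond - noise_pred_uncond)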
lightx2v/models/networks/wan/weights/post_weights.py
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, TENSOR_REGISTER
+from lightx2v.common.modules.weight_module import WeightModule


-class WanPostWeights:
+class WanPostWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.config = config
-
-    def load_weights(self, weight_dict):
-        self.head = MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias")
-        self.head_modulation = weight_dict["head.modulation"]
-        self.weight_list = [self.head]
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-                if self.config["cpu_offload"]:
-                    weight.to_cpu()
-                    self.head_modulation = self.head_modulation.cpu()
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cpu()
-        self.head_modulation = self.head_modulation.cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, MMWeightTemplate):
-                weight.to_cuda()
-        self.head_modulation = self.head_modulation.cuda()
+        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
+        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
lightx2v/models/networks/wan/weights/pre_weights.py
 import torch
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
-from lightx2v.common.ops.conv.conv3d import Conv3dWeightTemplate
+from lightx2v.common.modules.weight_module import WeightModule


-class WanPreWeights:
+class WanPreWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.in_dim = config["in_dim"]
         self.dim = config["dim"]
         self.patch_size = (1, 2, 2)
         self.config = config
-
-    def load_weights(self, weight_dict):
-        self.patch_embedding = CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size)
-        self.text_embedding_0 = MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias")
-        self.text_embedding_2 = MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias")
-        self.time_embedding_0 = MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias")
-        self.time_embedding_2 = MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias")
-        self.time_projection_1 = MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias")
-        self.weight_list = [
-            self.patch_embedding,
-            self.text_embedding_0,
-            self.text_embedding_2,
-            self.time_embedding_0,
-            self.time_embedding_2,
-            self.time_projection_1,
-        ]
-        if "img_emb.proj.0.weight" in weight_dict.keys():
-            self.proj_0 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5)
-            self.proj_1 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias")
-            self.proj_3 = MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias")
-            self.proj_4 = LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5)
-            self.weight_list.append(self.proj_0)
-            self.weight_list.append(self.proj_1)
-            self.weight_list.append(self.proj_3)
-            self.weight_list.append(self.proj_4)
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-                if self.config["cpu_offload"]:
-                    weight.to_cpu()
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
-                weight.to_cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, Conv3dWeightTemplate)):
-                weight.to_cuda()
+        self.add_module("patch_embedding", CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size))
+        self.add_module("text_embedding_0", MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"))
+        self.add_module("text_embedding_2", MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"))
+        self.add_module("time_embedding_0", MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"))
+        self.add_module("time_embedding_2", MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"))
+        self.add_module("time_projection_1", MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"))
+        if config.task == "i2v":
+            self.add_module("proj_0", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5))
+            self.add_module("proj_1", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"))
+            self.add_module("proj_3", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"))
+            self.add_module("proj_4", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5))
lightx2v/models/networks/wan/weights/transformer_weights.py
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
-from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
-from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
-from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER
+from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList


-class WanTransformerWeights:
+class WanTransformerWeights(WeightModule):
     def __init__(self, config):
+        super().__init__()
         self.blocks_num = config["num_layers"]
         self.task = config["task"]
         self.config = config
@@ -13,99 +12,39 @@ class WanTransformerWeights:
             self.mm_type = "Calib"
         else:
             self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
-
-    def load_weights(self, weight_dict):
-        self.blocks_weights = [WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]
-        for block in self.blocks_weights:
-            block.load_weights(weight_dict)
-
-    def to_cpu(self):
-        for block in self.blocks_weights:
-            block.to_cpu()
-
-    def to_cuda(self):
-        for block in self.blocks_weights:
-            block.to_cuda()
+        self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
+        self.add_module("blocks", self.blocks)


-class WanTransformerAttentionBlock:
+class WanTransformerAttentionBlock(WeightModule):
     def __init__(self, block_index, task, mm_type, config):
+        super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
-
-    def load_weights(self, weight_dict):
-        self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias")
-        self.self_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias")
-        self.self_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias")
-        self.self_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias")
-        self.self_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight")
-        self.self_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight")
-        self.norm3 = LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6)
-        self.cross_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias")
-        self.cross_attn_k = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias")
-        self.cross_attn_v = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias")
-        self.cross_attn_o = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias")
-        self.cross_attn_norm_q = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight")
-        self.cross_attn_norm_k = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight")
-        self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias")
-        self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias")
-        self.modulation = weight_dict[f"blocks.{self.block_index}.modulation"]
-        self.weight_list = [
-            self.self_attn_q,
-            self.self_attn_k,
-            self.self_attn_v,
-            self.self_attn_o,
-            self.self_attn_norm_q,
-            self.self_attn_norm_k,
-            self.norm3,
-            self.cross_attn_q,
-            self.cross_attn_k,
-            self.cross_attn_v,
-            self.cross_attn_o,
-            self.cross_attn_norm_q,
-            self.cross_attn_norm_k,
-            self.ffn_0,
-            self.ffn_2,
-        ]
         if self.task == "i2v":
-            self.cross_attn_k_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias")
-            self.cross_attn_v_img = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias")
-            self.cross_attn_norm_k_img = RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight")
-            self.weight_list.append(self.cross_attn_k_img)
-            self.weight_list.append(self.cross_attn_v_img)
-            self.weight_list.append(self.cross_attn_norm_k_img)
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.set_config(self.config["mm_config"])
-                weight.load(weight_dict)
-
-    def to_cpu(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu()
-        self.modulation = self.modulation.cpu()
-
-    def to_cuda(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda()
-        self.modulation = self.modulation.cuda()
-
-    def to_cpu_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cpu(non_blocking=True)
-        self.modulation = self.modulation.to("cpu", non_blocking=True)
-
-    def to_cuda_sync(self):
-        for weight in self.weight_list:
-            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
-                weight.to_cuda(non_blocking=True)
-        self.modulation = self.modulation.cuda(non_blocking=True)
+            self.add_module("cross_attn_k_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k_img.weight", f"blocks.{self.block_index}.cross_attn.k_img.bias"))
+            self.add_module("cross_attn_v_img", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v_img.weight", f"blocks.{self.block_index}.cross_attn.v_img.bias"))
+            self.add_module("cross_attn_norm_k_img", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k_img.weight"))
+        self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
+        self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
+        self.add_module("self_attn_v", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.v.weight", f"blocks.{self.block_index}.self_attn.v.bias"))
+        self.add_module("self_attn_o", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.o.weight", f"blocks.{self.block_index}.self_attn.o.bias"))
+        self.add_module("self_attn_norm_q", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_q.weight"))
+        self.add_module("self_attn_norm_k", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.self_attn.norm_k.weight"))
+        self.add_module("norm3", LN_WEIGHT_REGISTER["Default"](f"blocks.{self.block_index}.norm3.weight", f"blocks.{self.block_index}.norm3.bias", eps=1e-6))
+        self.add_module("cross_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.q.weight", f"blocks.{self.block_index}.cross_attn.q.bias"))
+        self.add_module("cross_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.k.weight", f"blocks.{self.block_index}.cross_attn.k.bias"))
+        self.add_module("cross_attn_v", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.v.weight", f"blocks.{self.block_index}.cross_attn.v.bias"))
+        self.add_module("cross_attn_o", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.cross_attn.o.weight", f"blocks.{self.block_index}.cross_attn.o.bias"))
+        self.add_module("cross_attn_norm_q", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_q.weight"))
+        self.add_module("cross_attn_norm_k", RMS_WEIGHT_REGISTER["sgl-kernel"](f"blocks.{self.block_index}.cross_attn.norm_k.weight"))
+        self.add_module("ffn_0", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias"))
+        self.add_module("ffn_2", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias"))
+        self.register_parameter("modulation", TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"))
lightx2v/utils/registry_factory.py
@@ -50,4 +50,6 @@ LN_WEIGHT_REGISTER = Register()
 CONV3D_WEIGHT_REGISTER = Register()
 CONV2D_WEIGHT_REGISTER = Register()
+TENSOR_REGISTER = Register()
 RUNNER_REGISTER = Register()
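Register's implementation is outside this diff; from the usage visible above (@TENSOR_REGISTER("Default") as a class decorator, TENSOR_REGISTER["Default"] for lookup) it behaves like a name-to-class registry. A minimal sketch consistent with those call sites, inferred rather than copied from the repo:

class Register:
    # sketch inferred from call sites; not the repo's actual code
    def __init__(self):
        self._map = {}

    def __call__(self, name):
        # used as a class decorator: @TENSOR_REGISTER("Default")
        def decorator(cls):
            self._map[name] = cls
            return cls
        return decorator

    def __getitem__(self, name):
        # used as a factory lookup: TENSOR_REGISTER["Default"](...)
        return self._map[name]


TENSOR_REGISTER = Register()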