Commit 74eeb429 (unverified)
Authored Dec 03, 2025 by Gu Shiqiao; committed by GitHub on Dec 03, 2025
Parent: f7cdbcb5

reconstruct disk offload and fix lightx2v_platform bugs (#558)

Co-authored-by: helloyongyang <yongyang1030@163.com>
Showing 20 changed files with 391 additions and 380 deletions (+391 -380).
Changed files:

  lightx2v/models/networks/qwen_image/weights/post_weights.py               +10   -0
  lightx2v/models/networks/qwen_image/weights/pre_weights.py                +10   -0
  lightx2v/models/networks/qwen_image/weights/transformer_weights.py        +46   -31
  lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py   +8    -5
  lightx2v/models/networks/wan/infer/offload/transformer_infer.py           +24   -105
  lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py              +2    -1
  lightx2v/models/networks/wan/infer/vace/transformer_infer.py              +3    -2
  lightx2v/models/networks/wan/model.py                                     +21   -30
  lightx2v/models/networks/wan/vace_model.py                                +9    -4
  lightx2v/models/networks/wan/weights/animate/transformer_weights.py       +19   -10
  lightx2v/models/networks/wan/weights/audio/transformer_weights.py         +47   -12
  lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py  +6    -77
  lightx2v/models/networks/wan/weights/transformer_weights.py               +154  -68
  lightx2v/models/networks/wan/weights/vace/transformer_weights.py          +12   -16
  lightx2v/models/runners/default_runner.py                                 +3    -6
  lightx2v/models/runners/wan/wan_animate_runner.py                         +6    -5
  lightx2v/models/runners/wan/wan_audio_runner.py                           +1    -1
  lightx2v/models/runners/wan/wan_matrix_game2_runner.py                    +6    -5
  lightx2v/models/runners/wan/wan_runner.py                                 +2    -2
  lightx2v/models/schedulers/hunyuan_video/posemb_layers.py                 +2    -0
lightx2v/models/networks/qwen_image/weights/post_weights.py

@@ -37,3 +37,13 @@ class QwenImagePostWeights(WeightModule):
                 self.lazy_load_file,
             ),
         )
+
+    def to_cpu(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cpu"):
+                module.to_cpu(non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cuda"):
+                module.to_cuda(non_blocking=non_blocking)
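The to_cpu/to_cuda methods added here (and in pre_weights.py below) simply walk the registered child modules and delegate. A minimal standalone sketch of that traversal, assuming, as the diff suggests, that WeightModule keeps its children in a _modules dict; ToyWeightModule and ToyLeaf are illustrative stand-ins, not the LightX2V classes:

import torch

class ToyLeaf:
    """Holds one tensor and knows how to move it between host and device."""
    def __init__(self, shape):
        self.weight = torch.zeros(shape)

    def to_cpu(self, non_blocking=True):
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        if torch.cuda.is_available():
            self.weight = self.weight.to("cuda", non_blocking=non_blocking)

class ToyWeightModule:
    """Container that only delegates, mirroring the two added methods."""
    def __init__(self):
        self._modules = {}

    def add_module(self, name, module):
        self._modules[name] = module

    def to_cpu(self, non_blocking=True):
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu"):
                module.to_cpu(non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda"):
                module.to_cuda(non_blocking=non_blocking)

post = ToyWeightModule()
post.add_module("proj_out", ToyLeaf((4, 4)))
post.to_cuda()   # no-op on machines without a GPU
post.to_cpu()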
lightx2v/models/networks/qwen_image/weights/pre_weights.py (mode 100644 → 100755)

@@ -28,3 +28,13 @@ class QwenImagePreWeights(WeightModule):
         self.add_module(
             "time_text_embed_timestep_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_2.weight", "time_text_embed.timestep_embedder.linear_2.bias")
         )
+
+    def to_cpu(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cpu"):
+                module.to_cpu(non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cuda"):
+                module.to_cuda(non_blocking=non_blocking)
lightx2v/models/networks/qwen_image/weights/transformer_weights.py

@@ -15,7 +15,7 @@ class QwenImageTransformerWeights(WeightModule):
         self.mm_type = config.get("dit_quant_scheme", "Default")
         if self.mm_type != "Default":
             assert config.get("dit_quantized") is True
-        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "transformer_blocks") for i in range(self.blocks_num))
+        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, False, "transformer_blocks") for i in range(self.blocks_num))
         self.register_offload_buffers(config)
         self.add_module("blocks", blocks)

@@ -23,17 +23,17 @@ class QwenImageTransformerWeights(WeightModule):
         if config["cpu_offload"]:
             if config["offload_granularity"] == "block":
                 self.offload_blocks_num = 2
-                self.offload_block_buffers = WeightModuleList(
-                    [QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, "transformer_blocks") for i in range(self.offload_blocks_num)]
+                self.offload_block_cuda_buffers = WeightModuleList(
+                    [QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, False, "transformer_blocks") for i in range(self.offload_blocks_num)]
                 )
-                self.add_module("offload_block_buffers", self.offload_block_buffers)
-                self.offload_phase_buffers = None
+                self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
+                self.offload_phase_cuda_buffers = None
             else:
                 raise NotImplementedError

 class QwenImageTransformerAttentionBlock(WeightModule):
-    def __init__(self, block_index, task, mm_type, config, is_offload_buffer=False, block_prefix="transformer_blocks"):
+    def __init__(self, block_index, task, mm_type, config, create_cuda_buffer=False, create_cpu_buffer=False, block_prefix="transformer_blocks"):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type

@@ -55,14 +55,15 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.img_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.img_mod.1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "img_norm1",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         self.attn = QwenImageCrossAttention(
             block_index=block_index,

@@ -70,7 +71,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )

@@ -78,7 +80,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
         self.add_module(
             "img_norm2",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         img_mlp = QwenImageFFN(
             block_index=block_index,

@@ -87,7 +89,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )

@@ -99,20 +102,21 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.txt_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.txt_mod.1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "txt_norm1",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         # Text doesn't need separate attention - it's handled by img_attn joint computation
         self.add_module(
             "txt_norm2",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         txt_mlp = QwenImageFFN(
             block_index=block_index,

@@ -121,7 +125,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )

@@ -129,7 +134,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
 class QwenImageCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type

@@ -146,12 +151,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_q
         self.add_module(
             "norm_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # norm_k
         self.add_module(
             "norm_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # to_q
         self.add_module(

@@ -159,7 +164,8 @@ (to_q), @@ -170,7 +176,8 @@ (to_k), @@ -181,7 +188,8 @@ (to_v),
@@ -192,7 +200,8 @@ (add_q_proj), @@ -203,7 +212,8 @@ (add_k_proj), @@ -214,7 +224,8 @@ (add_v_proj),
@@ -225,7 +236,8 @@ (to_out.0), @@ -236,7 +248,8 @@ (to_add_out) -- each projection's MM_WEIGHT_REGISTER call gets the same replacement:
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.<proj>.weight",
                 f"{block_prefix}.{self.block_index}.attn.<proj>.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),

@@ -244,12 +257,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_added_q
         self.add_module(
             "norm_added_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # norm_added_k
         self.add_module(
             "norm_added_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # attn
         self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]())

@@ -266,7 +279,7 @@ class QwenImageCrossAttention(WeightModule):
 class QwenImageFFN(WeightModule):
-    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type

@@ -281,7 +294,8 @@ (net.0.proj) and @@ -291,7 +305,8 @@ (net.2) -- the same replacement in the FFN weights:
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.weight",
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
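Throughout this file the single is_offload_buffer flag is split into the pair create_cuda_buffer / create_cpu_buffer, so a weight can own a pre-allocated GPU staging slot, a pinned host slot, or neither. A rough sketch of that distinction, using a hypothetical OffloadableWeight rather than the real MM_WEIGHT_REGISTER classes:

import torch

class OffloadableWeight:
    """Illustrative weight holder with optional CUDA and pinned-CPU staging buffers."""

    def __init__(self, shape, create_cuda_buffer=False, create_cpu_buffer=False):
        self.weight = torch.empty(shape)  # canonical copy (e.g. read from the checkpoint)
        self.cuda_buffer = None
        self.cpu_buffer = None
        if create_cuda_buffer and torch.cuda.is_available():
            # Fixed GPU slot that the offload manager prefetches into.
            self.cuda_buffer = torch.empty(shape, device="cuda")
        if create_cpu_buffer:
            # Pinned host slot so disk -> host -> device copies can run asynchronously.
            self.cpu_buffer = torch.empty(shape, pin_memory=torch.cuda.is_available())

w = OffloadableWeight((1024, 1024), create_cuda_buffer=False, create_cpu_buffer=True)
print(w.cpu_buffer.is_pinned() if torch.cuda.is_available() else "pinning skipped on CPU-only hosts")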
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
 from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE

 class WanTransformerInferCaching(WanOffloadTransformerInfer):

@@ -56,7 +57,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
             self.accumulated_rel_l1_distance_even = 0
         else:
             rescale_func = np.poly1d(self.coefficients)
-            self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even.cuda()).abs().mean() / self.previous_e0_even.cuda().abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_even += rescale_func(
+                ((modulated_inp - self.previous_e0_even.to(AI_DEVICE)).abs().mean() / self.previous_e0_even.to(AI_DEVICE).abs().mean()).cpu().item()
+            )
         if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
             should_calc = False
         else:

@@ -72,7 +75,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
             self.accumulated_rel_l1_distance_odd = 0
         else:
             rescale_func = np.poly1d(self.coefficients)
-            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.cuda()).abs().mean() / self.previous_e0_odd.cuda().abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.to(AI_DEVICE)).abs().mean() / self.previous_e0_odd.to(AI_DEVICE).abs().mean()).cpu().item())
         if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
             should_calc = False
         else:

@@ -149,9 +152,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def infer_using_cache(self, x):
         if self.scheduler.infer_condition:
-            x.add_(self.previous_residual_even.cuda())
+            x.add_(self.previous_residual_even.to(AI_DEVICE))
         else:
-            x.add_(self.previous_residual_odd.cuda())
+            x.add_(self.previous_residual_odd.to(AI_DEVICE))
         return x

     def clear(self):

@@ -1075,7 +1078,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
     def infer_using_cache(self, x):
         residual_x = self.residual_cache[self.scheduler.infer_condition]
-        x.add_(residual_x.cuda())
+        x.add_(residual_x.to(AI_DEVICE))
         return x

     def clear(self):
...
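These hunks replace hard-coded .cuda() calls with .to(AI_DEVICE), so the TeaCache/MagCache residual math runs on whichever backend lightx2v_platform selects. A small sketch of the substitution; AI_DEVICE is resolved locally here only for illustration (in the repository it comes from lightx2v_platform.base.global_var):

import torch

# Stand-in for: from lightx2v_platform.base.global_var import AI_DEVICE
AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

previous_e0 = torch.randn(4, 8)                      # cached on CPU between steps
modulated_inp = torch.randn(4, 8, device=AI_DEVICE)  # current step's modulated input

# Relative L1 distance, computed on the selected device instead of assuming CUDA.
rel_l1 = ((modulated_inp - previous_e0.to(AI_DEVICE)).abs().mean()
          / previous_e0.to(AI_DEVICE).abs().mean()).cpu().item()
print(rel_l1)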
lightx2v/models/networks/wan/infer/offload/transformer_infer.py

 import torch
-from lightx2v.common.offload.manager import (
-    LazyWeightAsyncStreamManager,
-    WeightAsyncStreamManager,
-)
+from lightx2v.common.offload.manager import WeightAsyncStreamManager
 from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)

 class WanOffloadTransformerInfer(WanTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
         if self.config.get("cpu_offload", False):
-            if "offload_ratio" in self.config:
-                self.offload_ratio = self.config["offload_ratio"]
-            else:
-                self.offload_ratio = 1
             offload_granularity = self.config.get("offload_granularity", "block")
             if offload_granularity == "block":
-                if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_blocks_offload
-                else:
-                    self.infer_func = self.infer_with_blocks_lazy_offload
+                self.infer_func = self.infer_with_blocks_offload
             elif offload_granularity == "phase":
-                if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_phases_offload
-                else:
-                    self.infer_func = self.infer_with_phases_lazy_offload
+                self.infer_func = self.infer_with_phases_offload
                 self.phase_params = {
                     "shift_msa": None,
                     "scale_msa": None,

@@ -41,121 +31,54 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
             self.infer_func = self.infer_without_offload
         if offload_granularity != "model":
-            if not self.config.get("lazy_load", False):
-                self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
-            else:
-                self.offload_manager = LazyWeightAsyncStreamManager(
-                    blocks_num=self.blocks_num,
-                    offload_ratio=self.offload_ratio,
-                    phases_num=self.phases_num,
-                    num_disk_workers=self.config.get("num_disk_workers", 2),
-                    max_memory=self.config.get("max_memory", 2),
-                    offload_gra=offload_granularity,
-                )
+            self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)

     def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            if block_idx == 0:
+            if self.offload_manager.need_init_first_buffer:
                 self.offload_manager.init_first_buffer(blocks)
-            if block_idx < len(blocks) - 1:
-                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            self.offload_manager.prefetch_weights((block_idx + 1) % len(blocks), blocks)
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
             self.offload_manager.swap_blocks()
-        return x
-
-    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
-        if self.clean_cuda_cache:
-            del (
-                self.phase_params["attn_out"],
-                self.phase_params["y_out"],
-                self.phase_params["y"],
-            )
-            torch.cuda.empty_cache()
-        if self.clean_cuda_cache:
-            self.clear_offload_params(pre_infer_out)
-        return x
-
-    def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
-        self.offload_manager.prefetch_weights_from_disk(blocks)
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            if block_idx == 0:
-                block = self.offload_manager.pin_memory_buffer.get(block_idx)
-                block.to_cuda()
-                self.offload_manager.cuda_buffers[0] = (block_idx, block)
-            if block_idx < len(blocks) - 1:
-                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
-                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
-            self.offload_manager.swap_blocks()
-            if block_idx == len(blocks) - 1:
-                self.offload_manager.pin_memory_buffer.pop_front()
-            self.offload_manager._async_prefetch_block(blocks)
         if self.clean_cuda_cache:
             del (
                 pre_infer_out.embed0,
-                pre_infer_out.freqs,
                 pre_infer_out.context,
             )
-            torch.cuda.empty_cache()
+            torch_device_module.empty_cache()
         return x

-    def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
-        self.offload_manager.prefetch_weights_from_disk(blocks)
+    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
-            self.offload_manager._async_prefetch_block(blocks)
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out)
         if self.clean_cuda_cache:
             del (
                 self.phase_params["attn_out"],
                 self.phase_params["y_out"],
                 self.phase_params["y"],
            )
-            torch.cuda.empty_cache()
+            torch_device_module.empty_cache()
         if self.clean_cuda_cache:
             self.clear_offload_params(pre_infer_out)
         return x

-    def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
+    def infer_phases(self, block_idx, blocks, x, pre_infer_out):
         for phase_idx in range(self.phases_num):
-            if block_idx == 0 and phase_idx == 0:
-                if lazy:
-                    obj_key = (block_idx, phase_idx)
-                    phase = self.offload_manager.pin_memory_buffer.get(obj_key)
-                    phase.to_cuda()
-                    self.offload_manager.cuda_buffers[0] = (obj_key, phase)
-                else:
-                    self.offload_manager.init_first_buffer(blocks)
-            is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
-            if not is_last_phase:
-                next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
-                next_phase_idx = (phase_idx + 1) % self.phases_num
-                self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            if self.offload_manager.need_init_first_buffer:
+                self.offload_manager.init_first_buffer(blocks)
+            next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
+            next_phase_idx = (phase_idx + 1) % self.phases_num
+            self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
             self.offload_manager.swap_phases()

@@ -176,10 +99,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
         self.phase_params["y_out"] = self.infer_self_attn(
             cur_phase,
-            pre_infer_out.grid_sizes.tuple,
             x,
-            pre_infer_out.seq_lens,
-            pre_infer_out.freqs,
             self.phase_params["shift_msa"],
             self.phase_params["scale_msa"],
         )

@@ -219,7 +139,6 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         )
         del (
             pre_infer_out.embed0,
-            pre_infer_out.freqs,
             pre_infer_out.context,
         )
-        torch.cuda.empty_cache()
+        torch_device_module.empty_cache()
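The rewritten block-offload loop keeps exactly two device buffers and leans on the stream manager: slot 0 is computed on the compute stream while the next block is prefetched, then the slots are swapped. A simplified, self-contained sketch of that ping-pong pattern (requires a CUDA device); ToyStreamManager is an illustrative stand-in for WeightAsyncStreamManager, and cross-stream synchronization is reduced to the minimum needed for the demo:

import torch

AI_DEVICE = "cuda"  # stand-in for lightx2v_platform's AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)

class ToyStreamManager:
    """Two device buffers: slot 0 is being computed, slot 1 is being prefetched."""

    def __init__(self):
        self.compute_stream = torch_device_module.Stream()
        self.load_stream = torch_device_module.Stream()
        self.cuda_buffers = [None, None]
        self.need_init_first_buffer = True

    def init_first_buffer(self, blocks):
        self.cuda_buffers[0] = blocks[0].to(AI_DEVICE, non_blocking=True)
        self.need_init_first_buffer = False

    def prefetch_weights(self, idx, blocks):
        # Copy the next block on the load stream while slot 0 is being used.
        with torch_device_module.stream(self.load_stream):
            self.cuda_buffers[1] = blocks[idx].to(AI_DEVICE, non_blocking=True)

    def swap_blocks(self):
        # Wait for the prefetch, then promote slot 1 to the compute slot.
        self.load_stream.synchronize()
        self.cuda_buffers[0], self.cuda_buffers[1] = self.cuda_buffers[1], self.cuda_buffers[0]

def run_blocks(manager, blocks, x):
    for block_idx in range(len(blocks)):
        if manager.need_init_first_buffer:
            manager.init_first_buffer(blocks)
        manager.prefetch_weights((block_idx + 1) % len(blocks), blocks)
        with torch_device_module.stream(manager.compute_stream):
            x = x @ manager.cuda_buffers[0]  # stand-in for infer_block(...)
        manager.swap_blocks()
    manager.compute_stream.synchronize()
    return x

if torch.cuda.is_available():
    blocks = [torch.eye(8) for _ in range(4)]   # pretend these are transformer blocks kept on CPU
    x = torch.randn(2, 8, device=AI_DEVICE)
    print(run_blocks(ToyStreamManager(), blocks, x).shape)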
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py (mode 100644 → 100755)

@@ -6,6 +6,7 @@ import torch
 from lightx2v.models.networks.wan.infer.module_io import GridOutput
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE

 def sinusoidal_embedding_1d(dim, position):

@@ -50,7 +51,7 @@ class WanSFPreInfer(WanPreInfer):
                 rope_params(1024, 2 * (d // 6)),
             ],
             dim=1,
-        ).cuda()
+        ).to(AI_DEVICE)

     def time_embedding(self, weights, embed):
         embed = weights.time_embedding_0.apply(embed)
...
lightx2v/models/networks/wan/infer/vace/transformer_infer.py

@@ -9,6 +9,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}

     def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
         pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
         self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)

@@ -23,11 +24,11 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         pre_infer_out.adapter_args["hints"] = []
         self.infer_state = "vace"
         if hasattr(self, "offload_manager"):
-            self.offload_manager.init_cuda_buffer(self.vace_offload_block_buffers, self.vace_offload_phase_buffers)
+            self.offload_manager.init_cuda_buffer(self.vace_offload_block_cuda_buffers, self.vace_offload_phase_cuda_buffers)
         self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
         self.infer_state = "base"
         if hasattr(self, "offload_manager"):
-            self.offload_manager.init_cuda_buffer(self.offload_block_buffers, self.offload_phase_buffers)
+            self.offload_manager.init_cuda_buffer(self.offload_block_cuda_buffers, self.offload_phase_cuda_buffers)

     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
...
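The VACE path switches the offload manager between two sets of staging buffers: the vace_* set while infer_state is "vace", and the base set afterwards. A tiny sketch of that switch; the manager class is hypothetical and only the attribute names come from the diff above:

class ToyOffloadManager:
    """Holds references to whichever staging buffers should be used next."""

    def __init__(self):
        self.block_buffers = None
        self.phase_buffers = None

    def init_cuda_buffer(self, block_cuda_buffers, phase_cuda_buffers):
        self.block_buffers = block_cuda_buffers
        self.phase_buffers = phase_cuda_buffers

manager = ToyOffloadManager()
vace_offload_block_cuda_buffers, vace_offload_phase_cuda_buffers = ["vace-block-buf"], ["vace-phase-buf"]
offload_block_cuda_buffers, offload_phase_cuda_buffers = ["base-block-buf"], ["base-phase-buf"]

# infer_state = "vace": point the manager at the VACE buffers ...
manager.init_cuda_buffer(vace_offload_block_cuda_buffers, vace_offload_phase_cuda_buffers)
# ... infer_state = "base": point it back at the main transformer buffers.
manager.init_cuda_buffer(offload_block_cuda_buffers, offload_phase_cuda_buffers)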
lightx2v/models/networks/wan/model.py

@@ -47,7 +47,10 @@ class WanModel(CompiledMethodsMixin):
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
         self.model_type = model_type
+        self.remove_keys = []
+        self.lazy_load = self.config.get("lazy_load", False)
+        if self.lazy_load:
+            self.remove_keys.extend(["blocks."])
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
         else:

@@ -146,7 +149,7 @@ class WanModel(CompiledMethodsMixin):
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.config["parallel"]:
+        if self.device.type != "cpu" and dist.is_initialized():
             device = dist.get_rank()
         else:
             device = str(self.device)

@@ -169,6 +172,10 @@ class WanModel(CompiledMethodsMixin):
         else:
             safetensors_files = [safetensors_path]
+        if self.lazy_load:
+            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
+            self.lazy_load_path = safetensors_files[0]
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:

@@ -205,6 +212,10 @@ class WanModel(CompiledMethodsMixin):
             safetensors_files = [safetensors_path]
             safetensors_path = os.path.dirname(safetensors_path)
+        if self.lazy_load:
+            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
+            self.lazy_load_path = safetensors_files[0]
         weight_dict = {}
         for safetensor_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:

@@ -237,28 +248,6 @@ class WanModel(CompiledMethodsMixin):
         return weight_dict

-    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
-        # Need rewrite
-        lazy_load_model_path = self.dit_quantized_ckpt
-        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
-        pre_post_weight_dict = {}
-        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
-        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
-            for k in f.keys():
-                if f.get_tensor(k).dtype in [
-                    torch.float16,
-                    torch.bfloat16,
-                    torch.float,
-                ]:
-                    if unified_dtype or all(s not in k for s in sensitive_layer):
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
-                    else:
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
-                else:
-                    pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
-        return pre_post_weight_dict
-
     def _load_gguf_ckpt(self, gguf_path):
         state_dict = load_gguf_sd_ckpt(gguf_path, to_device=self.device)
         return state_dict

@@ -285,10 +274,7 @@ class WanModel(CompiledMethodsMixin):
             weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
         else:
             # Load quantized weights
-            if not self.config.get("lazy_load", False):
-                weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
-            else:
-                weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+            weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
         if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
             weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)

@@ -302,7 +288,10 @@ class WanModel(CompiledMethodsMixin):
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
-        self.transformer_weights = self.transformer_weight_class(self.config)
+        if self.lazy_load:
+            self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path)
+        else:
+            self.transformer_weights = self.transformer_weight_class(self.config)
         if not self._should_init_empty_model():
             self._apply_weights()

@@ -383,7 +372,9 @@ class WanModel(CompiledMethodsMixin):
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
+            if self.lazy_load:
+                self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
...
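A hedged sketch of the lazy-load wiring these hunks add: when lazy_load is enabled, block weights are skipped during the eager checkpoint load (via remove_keys), exactly one safetensors file is required, and its path is handed to the transformer weight container so tensors can be read on demand. The helper name and the file path below are illustrative, not from the repository:

import os

def plan_lazy_load(config, safetensors_files):
    """Mimics the checks WanModel now performs before building transformer weights."""
    lazy_load = config.get("lazy_load", False)
    remove_keys = ["blocks."] if lazy_load else []   # block weights are loaded lazily, not eagerly
    lazy_load_path = None
    if lazy_load:
        assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
        lazy_load_path = safetensors_files[0]
    return remove_keys, lazy_load_path

remove_keys, lazy_load_path = plan_lazy_load(
    {"lazy_load": True},
    [os.path.join("/path/to/ckpt", "model.safetensors")],  # hypothetical checkpoint path
)
print(remove_keys, lazy_load_path)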
lightx2v/models/networks/wan/vace_model.py

@@ -22,10 +22,15 @@ class WanVaceModel(WanModel):
     def _init_infer(self):
         super()._init_infer()
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_block_buffers = self.transformer_weights.offload_block_buffers
-            self.transformer_infer.offload_phase_buffers = self.transformer_weights.offload_phase_buffers
-            self.transformer_infer.vace_offload_block_buffers = self.transformer_weights.vace_offload_block_buffers
-            self.transformer_infer.vace_offload_phase_buffers = self.transformer_weights.vace_offload_phase_buffers
+            self.transformer_infer.offload_block_cuda_buffers = self.transformer_weights.offload_block_cuda_buffers
+            self.transformer_infer.offload_phase_cuda_buffers = self.transformer_weights.offload_phase_cuda_buffers
+            self.transformer_infer.vace_offload_block_cuda_buffers = self.transformer_weights.vace_offload_block_cuda_buffers
+            self.transformer_infer.vace_offload_phase_cuda_buffers = self.transformer_weights.vace_offload_phase_cuda_buffers
+            if self.lazy_load:
+                self.transformer_infer.offload_block_cpu_buffers = self.transformer_weights.offload_block_cpu_buffers
+                self.transformer_infer.offload_phase_cpu_buffers = self.transformer_weights.offload_phase_cpu_buffers
+                self.transformer_infer.vace_offload_block_cpu_buffers = self.transformer_weights.vace_offload_block_cpu_buffers
+                self.transformer_infer.vace_offload_phase_cpu_buffers = self.transformer_weights.vace_offload_phase_cpu_buffers

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
...
lightx2v/models/networks/wan/weights/animate/transformer_weights.py

@@ -26,15 +26,19 @@ class WanAnimateTransformerWeights(WanTransformerWeights):
             self._add_animate_fuserblock_to_offload_buffers()

     def _add_animate_fuserblock_to_offload_buffers(self):
-        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+        if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
             for i in range(self.offload_blocks_num):
-                self.offload_block_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
-        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
-            self.offload_phase_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
+                self.offload_block_cuda_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
+                if self.lazy_load:
+                    self.offload_block_cpu_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))
+        elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
+            self.offload_phase_cuda_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
+            if self.lazy_load:
+                self.offload_phase_cpu_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))

 class WanAnimateFuserBlock(WeightModule):
-    def __init__(self, config, block_index, block_prefix, mm_type, is_offload_buffer=False):
+    def __init__(self, config, block_index, block_prefix, mm_type, create_cuda_buffer=False, create_cpu_buffer=False):
         super().__init__()
         self.config = config
         self.is_post_adapter = True

@@ -53,7 +57,8 @@ (linear1_kv), @@ -65,7 +70,8 @@ (linear1_q), @@ -76,7 +82,8 @@ (linear2) -- the same replacement in each MM_WEIGHT_REGISTER call:
             MM_WEIGHT_REGISTER[mm_type](
                 f"{block_prefix}.{block_index}.<name>.weight",
                 f"{block_prefix}.{block_index}.<name>.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,

@@ -87,7 +94,8 @@ (q_norm) and @@ -98,7 +106,8 @@ (k_norm) -- the same replacement in the RMS norm weights:
             "q_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.q_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
...
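A sketch of what appending the fuser block to compute_phases does for the offload path: each staging buffer (CUDA, and CPU when lazy loading) gets an extra phase so the adapter is swapped in and out together with the block it follows. The classes here are illustrative stand-ins, not the LightX2V types:

class ToyPhase:
    def __init__(self, name, create_cuda_buffer=False, create_cpu_buffer=False):
        self.name = name
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer

class ToyBlockBuffer:
    def __init__(self):
        self.compute_phases = []

offload_block_cuda_buffers = [ToyBlockBuffer() for _ in range(2)]
offload_block_cpu_buffers = [ToyBlockBuffer() for _ in range(2)]
lazy_load = True

for i in range(2):
    # GPU-side staging slot gets a device-resident fuser phase ...
    offload_block_cuda_buffers[i].compute_phases.append(
        ToyPhase("face_adapter.fuser_blocks", create_cuda_buffer=True))
    # ... and under lazy load the pinned CPU slot gets its host-side twin.
    if lazy_load:
        offload_block_cpu_buffers[i].compute_phases.append(
            ToyPhase("face_adapter.fuser_blocks", create_cpu_buffer=True))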
lightx2v/models/networks/wan/weights/audio/transformer_weights.py
...
@@ -19,6 +19,7 @@ class WanAudioTransformerWeights(WanTransformerWeights):
                     self.mm_type,
                     self.config,
                     False,
+                    False,
                     self.blocks[i].lazy_load,
                     self.blocks[i].lazy_load_file,
                 )
...
@@ -27,37 +28,66 @@ class WanAudioTransformerWeights(WanTransformerWeights):
         self._add_audio_adapter_ca_to_offload_buffers()

     def _add_audio_adapter_ca_to_offload_buffers(self):
-        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+        if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
             for i in range(self.offload_blocks_num):
-                offload_buffer = self.offload_block_buffers[i]
+                offload_buffer = self.offload_block_cuda_buffers[i]
                 adapter_ca = WanAudioAdapterCA(
                     block_index=i,
                     block_prefix=f"ca",
                     task=self.task,
                     mm_type=self.mm_type,
                     config=self.config,
-                    is_offload_buffer=True,
+                    create_cuda_buffer=True,
+                    create_cpu_buffer=False,
                     lazy_load=offload_buffer.lazy_load,
                     lazy_load_file=offload_buffer.lazy_load_file,
                 )
                 offload_buffer.compute_phases.append(adapter_ca)
+                if self.lazy_load:
+                    offload_buffer = self.offload_block_cpu_buffers[i]
+                    adapter_ca = WanAudioAdapterCA(
+                        block_index=i,
+                        block_prefix=f"ca",
+                        task=self.task,
+                        mm_type=self.mm_type,
+                        config=self.config,
+                        create_cuda_buffer=False,
+                        create_cpu_buffer=True,
+                        lazy_load=offload_buffer.lazy_load,
+                        lazy_load_file=offload_buffer.lazy_load_file,
+                    )
+                    offload_buffer.compute_phases.append(adapter_ca)
-        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
+        elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
             adapter_ca = WanAudioAdapterCA(
                 block_index=0,
                 block_prefix=f"ca",
                 task=self.task,
                 mm_type=self.mm_type,
                 config=self.config,
-                is_offload_buffer=True,
+                create_cuda_buffer=True,
+                create_cpu_buffer=False,
                 lazy_load=self.blocks[0].lazy_load,
                 lazy_load_file=self.blocks[0].lazy_load_file,
             )
-            self.offload_phase_buffers.append(adapter_ca)
+            self.offload_phase_cuda_buffers.append(adapter_ca)
+            if self.lazy_load:
+                adapter_ca = WanAudioAdapterCA(
+                    block_index=0,
+                    block_prefix=f"ca",
+                    task=self.task,
+                    mm_type=self.mm_type,
+                    config=self.config,
+                    create_cuda_buffer=False,
+                    create_cpu_buffer=True,
+                    lazy_load=self.blocks[0].lazy_load,
+                    lazy_load_file=self.blocks[0].lazy_load_file,
+                )
+                self.offload_phase_cpu_buffers.append(adapter_ca)


 class WanAudioAdapterCA(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
...
@@ -71,7 +101,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_q.weight",
                 f"{block_prefix}.{block_index}.to_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -82,7 +113,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_kv.weight",
                 f"{block_prefix}.{block_index}.to_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -93,7 +125,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_out.weight",
                 f"{block_prefix}.{block_index}.to_out.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -104,7 +137,8 @@ class WanAudioAdapterCA(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.norm_kv.weight",
                 f"{block_prefix}.{block_index}.norm_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -119,7 +153,8 @@ class WanAudioAdapterCA(WeightModule):
             "shift_scale_gate",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.shift_scale_gate",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
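The recurring change in this file (and in the other weight modules below) replaces the single is_offload_buffer flag with an explicit create_cuda_buffer / create_cpu_buffer pair, and registers the pinned-CPU copies only when lazy_load is enabled. A minimal standalone sketch of that shape, assuming plain PyTorch and made-up class names rather than the project's actual WeightModule API:

# Illustrative sketch only -- mirrors the create_cuda_buffer / create_cpu_buffer
# split used in this commit, with hypothetical names and plain torch tensors.
import torch


class BufferSlot:
    """One staging slot for a transformer block's weights."""

    def __init__(self, shape, create_cuda_buffer=False, create_cpu_buffer=False):
        self.cuda_buffer = None
        self.cpu_buffer = None
        if create_cuda_buffer and torch.cuda.is_available():
            # GPU-resident staging buffer that block weights are copied into before compute.
            self.cuda_buffer = torch.empty(shape, device="cuda")
        if create_cpu_buffer:
            # Pinned host buffer, only needed when weights are lazily read from disk.
            self.cpu_buffer = torch.empty(shape, pin_memory=torch.cuda.is_available())


def register_offload_slots(shape, offload_slots_num=2, lazy_load=False):
    # Always create the pair of CUDA staging slots (double buffering).
    cuda_slots = [BufferSlot(shape, create_cuda_buffer=True) for _ in range(offload_slots_num)]
    # CPU slots exist only for the disk-offload (lazy_load) path.
    cpu_slots = [BufferSlot(shape, create_cpu_buffer=True) for _ in range(offload_slots_num)] if lazy_load else None
    return cuda_slots, cpu_slots

Keeping the two flags separate lets one block class describe a GPU staging slot, a pinned host slot, or an ordinary resident block, which is what the renamed offload_block_cuda_buffers / offload_block_cpu_buffers lists rely on.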
lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py  100644 → 100755
-import os
-
-from safetensors import safe_open
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanFFN,
...
@@ -31,9 +27,9 @@ class WanActionTransformerWeights(WeightModule):
         block_list = []
         for i in range(self.blocks_num):
             if i in action_blocks:
-                block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config, "blocks"))
+                block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config))
             else:
-                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "blocks"))
+                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config))
         self.blocks = WeightModuleList(block_list)
         self.add_module("blocks", self.blocks)
...
@@ -42,11 +38,6 @@ class WanActionTransformerWeights(WeightModule):
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

-    def clear(self):
-        for block in self.blocks:
-            for phase in block.compute_phases:
-                phase.clear()
-
     def non_block_weights_to_cuda(self):
         self.norm.to_cuda()
         self.head.to_cuda()
...
@@ -66,34 +57,16 @@ class WanTransformerActionBlock(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.lazy_load = self.config.get("lazy_load", False)
-        if self.lazy_load:
-            lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
-            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
-        else:
-            self.lazy_load_file = None
+        assert not self.config.get("lazy_load", False)
         self.compute_phases = WeightModuleList(
             [
-                WanSelfAttention(
-                    block_index,
-                    block_prefix,
-                    task,
-                    mm_type,
-                    config,
-                    False,
-                    self.lazy_load,
-                    self.lazy_load_file,
-                ),
+                WanSelfAttention(block_index, block_prefix, task, mm_type, config),
                 WanActionCrossAttention(
                     block_index,
                     block_prefix,
                     task,
                     mm_type,
                     config,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
                 WanActionModule(
                     block_index,
...
@@ -101,8 +74,6 @@ class WanTransformerActionBlock(WeightModule):
                     task,
                     mm_type,
                     config,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
                 WanFFN(
                     block_index,
...
@@ -110,9 +81,6 @@ class WanTransformerActionBlock(WeightModule):
                     task,
                     mm_type,
                     config,
-                    False,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
             ]
         )
...
@@ -121,7 +89,7 @@ class WanActionModule(WeightModule):
 class WanActionModule(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
...
@@ -129,9 +97,6 @@ class WanActionModule(WeightModule):
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
         self.attn_rms_type = "self_forcing"
         self.add_module(
...
@@ -139,8 +104,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.weight",
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
...
@@ -148,8 +111,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.weight",
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
...
@@ -158,8 +119,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.proj_keyboard.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
...
@@ -168,8 +127,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_attn_kv.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
...
@@ -180,8 +137,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_attn_q.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
...
@@ -191,8 +146,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.t_qkv.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
...
@@ -201,8 +154,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.proj_mouse.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
...
@@ -211,8 +162,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.weight",
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
...
@@ -220,8 +169,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.weight",
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
...
@@ -229,22 +176,18 @@ class WanActionModule(WeightModule):
             LN_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.weight",
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.bias",
-                self.lazy_load,
-                self.lazy_load_file,
                 eps=1e-6,
             ),
         )


 class WanActionCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
-        self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
         if self.config.get("sf_config", False):
             self.attn_rms_type = "self_forcing"
...
@@ -256,8 +199,6 @@ class WanActionCrossAttention(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.norm3.weight",
                 f"{block_prefix}.{self.block_index}.norm3.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
...
@@ -265,8 +206,6 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
...
@@ -274,8 +213,6 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
...
@@ -283,8 +220,6 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
...
@@ -292,24 +227,18 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
             "cross_attn_norm_q",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
             "cross_attn_norm_k",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
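Note that the Matrix-Game2 action blocks drop the per-block safetensors handling instead of porting it to the new buffer scheme, so the constructor now refuses the unsupported configuration up front. A hedged sketch of that guard style, using a hypothetical class name and config dict rather than the project's loader:

# Sketch of failing fast on a feature flag this code path does not implement;
# the config key mirrors the diff, the class itself is made up.
class ActionBlockStub:
    def __init__(self, config: dict):
        # Disk offload (lazy_load) is handled by the shared Wan transformer weights;
        # this specialised block does not support it, so refuse early and loudly.
        assert not config.get("lazy_load", False), "lazy_load is not supported for action blocks"
        self.config = config


ActionBlockStub({"lazy_load": False})   # fine
# ActionBlockStub({"lazy_load": True})  # would raise AssertionError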
lightx2v/models/networks/wan/weights/transformer_weights.py
-import os
-
 from safetensors import safe_open

 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
...
@@ -13,7 +11,7 @@ from lightx2v.utils.registry_factory import (
 class WanTransformerWeights(WeightModule):
-    def __init__(self, config):
+    def __init__(self, config, lazy_load_path=None):
         super().__init__()
         self.blocks_num = config["num_layers"]
         self.task = config["task"]
...
@@ -23,7 +21,27 @@ class WanTransformerWeights(WeightModule):
             assert config.get("dit_quantized") is True
             if config.get("do_mm_calib", False):
                 self.mm_type = "Calib"
-        self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
+        self.lazy_load = self.config.get("lazy_load", False)
+        if not self.lazy_load:
+            self.lazy_load_file = None
+        else:
+            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
+        self.blocks = WeightModuleList(
+            [
+                WanTransformerAttentionBlock(
+                    block_index=i,
+                    task=self.task,
+                    mm_type=self.mm_type,
+                    config=self.config,
+                    create_cuda_buffer=False,
+                    create_cpu_buffer=False,
+                    block_prefix="blocks",
+                    lazy_load=self.lazy_load,
+                    lazy_load_file=self.lazy_load_file,
+                )
+                for i in range(self.blocks_num)
+            ]
+        )
         self.register_offload_buffers(config)
         self.add_module("blocks", self.blocks)
...
@@ -36,35 +54,74 @@ class WanTransformerWeights(WeightModule):
         if config["cpu_offload"]:
             if config["offload_granularity"] == "block":
                 self.offload_blocks_num = 2
-                self.offload_block_buffers = WeightModuleList(
+                self.offload_block_cuda_buffers = WeightModuleList(
                     [
                         WanTransformerAttentionBlock(
-                            i,
-                            self.task,
-                            self.mm_type,
-                            self.config,
-                            True,
+                            block_index=i,
+                            task=self.task,
+                            mm_type=self.mm_type,
+                            config=self.config,
+                            create_cuda_buffer=True,
+                            create_cpu_buffer=False,
+                            block_prefix="blocks",
+                            lazy_load=self.lazy_load,
+                            lazy_load_file=self.lazy_load_file,
                         )
                         for i in range(self.offload_blocks_num)
                     ]
                 )
-                self.add_module("offload_block_buffers", self.offload_block_buffers)
-                self.offload_phase_buffers = None
+                self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
+                self.offload_phase_cuda_buffers = None
+                if self.lazy_load:
+                    self.offload_blocks_num = 2
+                    self.offload_block_cpu_buffers = WeightModuleList(
+                        [
+                            WanTransformerAttentionBlock(
+                                block_index=i,
+                                task=self.task,
+                                mm_type=self.mm_type,
+                                config=self.config,
+                                create_cuda_buffer=False,
+                                create_cpu_buffer=True,
+                                block_prefix="blocks",
+                                lazy_load=self.lazy_load,
+                                lazy_load_file=self.lazy_load_file,
+                            )
+                            for i in range(self.offload_blocks_num)
+                        ]
+                    )
+                    self.add_module("offload_block_cpu_buffers", self.offload_block_cpu_buffers)
+                    self.offload_phase_cpu_buffers = None
             elif config["offload_granularity"] == "phase":
-                self.offload_phase_buffers = WanTransformerAttentionBlock(
-                    0,
-                    self.task,
-                    self.mm_type,
-                    self.config,
-                    True,
-                ).compute_phases
-                self.add_module("offload_phase_buffers", self.offload_phase_buffers)
-                self.offload_block_buffers = None
-
-    def clear(self):
-        for block in self.blocks:
-            for phase in block.compute_phases:
-                phase.clear()
+                self.offload_phase_cuda_buffers = WanTransformerAttentionBlock(
+                    block_index=0,
+                    task=self.task,
+                    mm_type=self.mm_type,
+                    config=self.config,
+                    create_cuda_buffer=True,
+                    create_cpu_buffer=False,
+                    block_prefix="blocks",
+                    lazy_load=self.lazy_load,
+                    lazy_load_file=self.lazy_load_file,
+                ).compute_phases
+                self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers)
+                self.offload_block_cuda_buffers = None
+                if self.lazy_load:
+                    self.offload_phase_cpu_buffers = WanTransformerAttentionBlock(
+                        block_index=0,
+                        task=self.task,
+                        mm_type=self.mm_type,
+                        config=self.config,
+                        create_cuda_buffer=False,
+                        create_cpu_buffer=True,
+                        block_prefix="blocks",
+                        lazy_load=self.lazy_load,
+                        lazy_load_file=self.lazy_load_file,
+                    ).compute_phases
+                    self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers)
+                    self.offload_block_cpu_buffers = None

     def non_block_weights_to_cuda(self):
         self.norm.to_cuda()
...
@@ -84,23 +141,23 @@ class WanTransformerAttentionBlock(WeightModule):
         task,
         mm_type,
         config,
-        is_offload_buffer=False,
+        create_cuda_buffer=False,
+        create_cpu_buffer=False,
         block_prefix="blocks",
+        lazy_load=False,
+        lazy_load_file=None,
     ):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
-        self.is_offload_buffer = is_offload_buffer
+        self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.quant_method = config.get("quant_method", None)
-        self.lazy_load = self.config.get("lazy_load", False)
-        if self.lazy_load:
-            lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
-            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
-        else:
-            self.lazy_load_file = None
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
         self.compute_phases = WeightModuleList(
             [
...
@@ -110,7 +167,8 @@ class WanTransformerAttentionBlock(WeightModule):
                     task,
                     mm_type,
                     config,
-                    is_offload_buffer,
+                    create_cuda_buffer,
+                    create_cpu_buffer,
                     self.lazy_load,
                     self.lazy_load_file,
                 ),
...
@@ -120,7 +178,8 @@ class WanTransformerAttentionBlock(WeightModule):
                    task,
                    mm_type,
                    config,
-                   is_offload_buffer,
+                   create_cuda_buffer,
+                   create_cpu_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
...
@@ -130,7 +189,8 @@ class WanTransformerAttentionBlock(WeightModule):
                    task,
                    mm_type,
                    config,
-                   is_offload_buffer,
+                   create_cuda_buffer,
+                   create_cpu_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
...
@@ -148,9 +208,10 @@ class WanSelfAttention(WeightModule):
         task,
         mm_type,
         config,
-        is_offload_buffer,
-        lazy_load,
-        lazy_load_file,
+        create_cuda_buffer=False,
+        create_cpu_buffer=False,
+        lazy_load=False,
+        lazy_load_file=None,
     ):
         super().__init__()
         self.block_index = block_index
...
@@ -171,7 +232,8 @@ class WanSelfAttention(WeightModule):
             "modulation",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.modulation",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -187,7 +249,8 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.q.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -198,7 +261,8 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.k.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.k.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -208,7 +272,8 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.v.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.v.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -218,7 +283,8 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.o.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.o.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -227,7 +293,8 @@ class WanSelfAttention(WeightModule):
             "self_attn_norm_q",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -236,7 +303,8 @@ class WanSelfAttention(WeightModule):
             "self_attn_norm_k",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -278,7 +346,8 @@ class WanSelfAttention(WeightModule):
             "smooth_norm1_weight",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.affine_norm1.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -287,7 +356,8 @@ class WanSelfAttention(WeightModule):
             "smooth_norm1_bias",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.affine_norm1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -302,9 +372,10 @@ class WanCrossAttention(WeightModule):
         task,
         mm_type,
         config,
-        is_offload_buffer,
-        lazy_load,
-        lazy_load_file,
+        create_cuda_buffer=False,
+        create_cpu_buffer=False,
+        lazy_load=False,
+        lazy_load_file=None,
     ):
         super().__init__()
         self.block_index = block_index
...
@@ -324,7 +395,8 @@ class WanCrossAttention(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.norm3.weight",
                 f"{block_prefix}.{self.block_index}.norm3.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -334,7 +406,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -344,7 +417,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -354,7 +428,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -364,7 +439,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -373,7 +449,8 @@ class WanCrossAttention(WeightModule):
             "cross_attn_norm_q",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -382,7 +459,8 @@ class WanCrossAttention(WeightModule):
             "cross_attn_norm_k",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -395,7 +473,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -405,7 +484,8 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -414,7 +494,8 @@ class WanCrossAttention(WeightModule):
             "cross_attn_norm_k_img",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -430,9 +511,10 @@ class WanFFN(WeightModule):
         task,
         mm_type,
         config,
-        is_offload_buffer,
-        lazy_load,
-        lazy_load_file,
+        create_cuda_buffer=False,
+        create_cpu_buffer=False,
+        lazy_load=False,
+        lazy_load_file=None,
     ):
         super().__init__()
         self.block_index = block_index
...
@@ -453,7 +535,8 @@ class WanFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.ffn.0.weight",
                 f"{block_prefix}.{self.block_index}.ffn.0.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -463,7 +546,8 @@ class WanFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.ffn.2.weight",
                 f"{block_prefix}.{self.block_index}.ffn.2.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -474,7 +558,8 @@ class WanFFN(WeightModule):
             "smooth_norm2_weight",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.affine_norm3.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -483,7 +568,8 @@ class WanFFN(WeightModule):
             "smooth_norm2_bias",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.affine_norm3.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
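The net effect of the hunks above is that WanTransformerWeights opens the safetensors checkpoint once (when lazy_load is enabled) and threads the open handle into every block, instead of each block calling safe_open on its own shard. A rough sketch of that wiring, assuming a placeholder checkpoint path and tensor key layout (the real classes and key names live in this repository, not in the sketch):

# Minimal sketch of the disk-offload wiring after this refactor: the top-level
# weights object owns the safetensors handle; blocks call get_tensor() on demand.
from safetensors import safe_open


class LazyBlockStub:
    def __init__(self, block_index, lazy_load=False, lazy_load_file=None):
        self.block_index = block_index
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file  # shared handle, owned by the parent

    def weight(self, name):
        # Read a single tensor from disk only when it is actually needed (stays on CPU here).
        return self.lazy_load_file.get_tensor(f"blocks.{self.block_index}.{name}")


def build_blocks(num_layers, lazy_load, lazy_load_path):
    # One safe_open call for the whole model, mirroring the new WanTransformerWeights.__init__.
    handle = safe_open(lazy_load_path, framework="pt", device="cpu") if lazy_load else None
    return [LazyBlockStub(i, lazy_load, handle) for i in range(num_layers)]

Centralising the handle also explains the removed per-class clear() helpers: lifetime of the open file and of the staging buffers is now managed in one place rather than in every subclass.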
lightx2v/models/networks/wan/weights/vace/transformer_weights.py
...
@@ -15,7 +15,7 @@ class WanVaceTransformerWeights(WanTransformerWeights):
         self.patch_size = (1, 2, 2)
         self.register_offload_buffers(config)
         self.vace_blocks = WeightModuleList(
-            [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
+            [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
         )
         self.add_module("vace_blocks", self.vace_blocks)
         self.add_module(
...
@@ -27,23 +27,17 @@ class WanVaceTransformerWeights(WanTransformerWeights):
         super().register_offload_buffers(config)
         if config["cpu_offload"]:
             if config["offload_granularity"] == "block":
-                self.vace_offload_block_buffers = WeightModuleList(
+                self.vace_offload_block_cuda_buffers = WeightModuleList(
                     [
-                        WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
-                        WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
+                        WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
+                        WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
                     ]
                 )
-                self.add_module("vace_offload_block_buffers", self.vace_offload_block_buffers)
-                self.vace_offload_phase_buffers = None
+                self.add_module("vace_offload_block_cuda_buffers", self.vace_offload_block_cuda_buffers)
+                self.vace_offload_phase_cuda_buffers = None
             elif config["offload_granularity"] == "phase":
                 raise NotImplementedError

-    def clear(self):
-        super().clear()
-        for vace_block in self.vace_blocks:
-            for vace_phase in vace_block.compute_phases:
-                vace_phase.clear()
-
     def non_block_weights_to_cuda(self):
         super().non_block_weights_to_cuda()
         self.vace_patch_embedding.to_cuda()
...
@@ -54,15 +48,16 @@ class WanVaceTransformerWeights(WanTransformerWeights):
 class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
-    def __init__(self, base_block_idx, block_index, task, mm_type, config, is_offload_buffer, block_prefix):
-        super().__init__(block_index, task, mm_type, config, is_offload_buffer, block_prefix)
+    def __init__(self, base_block_idx, block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix):
+        super().__init__(block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix)
         if base_block_idx == 0:
             self.compute_phases[0].add_module(
                 "before_proj",
                 MM_WEIGHT_REGISTER[self.mm_type](
                     f"{block_prefix}.{self.block_index}.before_proj.weight",
                     f"{block_prefix}.{self.block_index}.before_proj.bias",
-                    is_offload_buffer,
+                    create_cuda_buffer,
+                    create_cpu_buffer,
                     self.lazy_load,
                     self.lazy_load_file,
                 ),
...
@@ -73,7 +68,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
                 MM_WEIGHT_REGISTER[self.mm_type](
                     f"{block_prefix}.{self.block_index}.after_proj.weight",
                     f"{block_prefix}.{self.block_index}.after_proj.bias",
-                    is_offload_buffer,
+                    create_cuda_buffer,
+                    create_cpu_buffer,
                     self.lazy_load,
                     self.lazy_load_file,
                 ),
...
lightx2v/models/runners/default_runner.py
...
@@ -41,7 +41,8 @@ class DefaultRunner(BaseRunner):
             self.load_model()
         elif self.config.get("lazy_load", False):
             assert self.config.get("cpu_offload", False)
-        self.model.set_scheduler(self.scheduler)  # set scheduler to model
+        if hasattr(self, "model"):
+            self.model.set_scheduler(self.scheduler)  # set scheduler to model
         if self.config["task"] == "i2v":
             self.run_input_encoder = self._run_input_encoder_local_i2v
         elif self.config["task"] == "flf2v":
...
@@ -184,11 +185,6 @@ class DefaultRunner(BaseRunner):
         del self.inputs
         self.input_info = None
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
-            if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
-                self.model.transformer_infer.weights_stream_mgr.clear()
-            if hasattr(self.model.transformer_weights, "clear"):
-                self.model.transformer_weights.clear()
-            self.model.pre_weight.clear()
             del self.model
         if self.config.get("do_mm_calib", False):
             calib_path = os.path.join(os.getcwd(), "calib.pt")
...
@@ -279,6 +275,7 @@ class DefaultRunner(BaseRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
+            self.model.set_scheduler(self.scheduler)
         self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
         if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
...
lightx2v/models/runners/wan/wan_animate_runner.py
...
@@ -24,6 +24,7 @@ from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
+from lightx2v_platform.base.global_var import AI_DEVICE


 @RUNNER_REGISTER("wan2.2_animate")
...
@@ -182,7 +183,7 @@ class WanAnimateRunner(WanRunner):
                 ],
                 dim=1,
             )
-            .cuda()
+            .to(AI_DEVICE)
             .unsqueeze(0)
         )
         mask_pixel_values = 1 - mask_pixel_values
...
@@ -210,7 +211,7 @@ class WanAnimateRunner(WanRunner):
                 ],
                 dim=1,
             )
-            .cuda()
+            .to(AI_DEVICE)
             .unsqueeze(0)
         )
         msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
...
@@ -330,7 +331,7 @@ class WanAnimateRunner(WanRunner):
                 dtype=GET_DTYPE(),
             )  # c t h w
         else:
-            refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().cuda()  # c t h w
+            refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().to(AI_DEVICE)  # c t h w
         bg_pixel_values, mask_pixel_values = None, None
         if self.config["replace_flag"] if "replace_flag" in self.config else False:
...
@@ -408,8 +409,8 @@ class WanAnimateRunner(WanRunner):
         return model

     def load_encoders(self):
-        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
-        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
+        motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
+        face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
         motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
         face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")
         motion_encoder.load_state_dict(motion_weight_dict)
...
lightx2v/models/runners/wan/wan_audio_runner.py
...
@@ -435,7 +435,7 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def process_single_mask(self, mask_file):
         mask_img = load_image(mask_file)
-        mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+        mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
         if mask_img.shape[1] == 3:  # If it is an RGB three-channel image
             mask_img = mask_img[:, :1]  # Only take the first channel
...
lightx2v/models/runners/wan/wan_matrix_game2_runner.py  100644 → 100755
...
@@ -13,6 +13,7 @@ from lightx2v.server.metrics import monitor_cli
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE


 class VAEWrapper:
...
@@ -90,8 +91,8 @@ def get_current_action(mode="universal"):
             flag = 1
         except Exception as e:
             pass
-        mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).cuda()
+        mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).to(AI_DEVICE)
-        keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
+        keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
     elif mode == "gta_drive":
         print()
         print("-" * 30)
...
@@ -118,8 +119,8 @@ def get_current_action(mode="universal"):
             flag = 1
         except Exception as e:
             pass
-        mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).cuda()
+        mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).to(AI_DEVICE)
-        keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).cuda()
+        keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).to(AI_DEVICE)
     elif mode == "templerun":
         print()
         print("-" * 30)
...
@@ -142,7 +143,7 @@ def get_current_action(mode="universal"):
             flag = 1
         except Exception as e:
             pass
-        keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
+        keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
     if mode != "templerun":
         return {"mouse": mouse_cond, "keyboard": keyboard_cond}
...
lightx2v/models/runners/wan/wan_runner.py
...
@@ -164,7 +164,7 @@ class WanRunner(DefaultRunner):
         if vae_offload:
             vae_device = torch.device("cpu")
         else:
-            vae_device = torch.device(self.init_device)
+            vae_device = torch.device(AI_DEVICE)
         vae_config = {
             "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
...
@@ -178,7 +178,7 @@ class WanRunner(DefaultRunner):
         }
         if self.config.get("use_tae", False):
             tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
-            vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
+            vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to(AI_DEVICE)
         else:
             vae_decoder = self.vae_cls(**vae_config)
         return vae_decoder
...
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py  100644 → 100755
...
@@ -2,6 +2,8 @@ from typing import List, Tuple, Union
 import torch

+from lightx2v_platform.base.global_var import AI_DEVICE
+

 def _to_tuple(x, dim=2):
     if isinstance(x, int):
...
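All of the runner and scheduler edits above follow one pattern: hard-coded .cuda() and torch.device("cuda") calls are replaced by the platform-level AI_DEVICE constant imported from lightx2v_platform.base.global_var. A hedged sketch of what such a constant can look like, assuming a simple availability check (the actual lightx2v_platform implementation may resolve the device differently, e.g. for non-NVIDIA accelerators):

# Illustrative device-abstraction sketch; only the usage pattern matches the diff.
import torch

if torch.cuda.is_available():
    AI_DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    AI_DEVICE = "mps"
else:
    AI_DEVICE = "cpu"

x = torch.randn(4, 4).to(AI_DEVICE)      # replaces x.cuda()
vae_device = torch.device(AI_DEVICE)     # replaces torch.device("cuda") / torch.device(self.init_device)

Routing every placement through one constant is what lets the same runner code drive CUDA and non-CUDA backends, which is the "lightx2v_platform" half of this commit's title.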