xuwx1 / LightX2V · Commits

Commit 51be3ad2 (unverified)
Authored Nov 14, 2025 by gushiqiao, committed via GitHub on Nov 14, 2025

[Fix] remove d2h of cpu-offload infer (#476)

parent 2559b3e7
Showing 14 changed files with 387 additions and 158 deletions (+387, −158)
lightx2v/models/networks/qwen_image/weights/post_weights.py            +2    −7
lightx2v/models/networks/qwen_image/weights/transformer_weights.py     +50   −60
lightx2v/models/networks/wan/infer/animate/transformer_infer.py        +37   −1
lightx2v/models/networks/wan/infer/offload/transformer_infer.py        +46   −62
lightx2v/models/networks/wan/infer/transformer_infer.py                +1    −1
lightx2v/models/networks/wan/infer/vace/transformer_infer.py           +4    −4
lightx2v/models/networks/wan/model.py                                  +2    −0
lightx2v/models/networks/wan/vace_model.py                             +8    −0
lightx2v/models/networks/wan/weights/animate/transformer_weights.py    +49   −6
lightx2v/models/networks/wan/weights/audio/transformer_weights.py      +38   −1
lightx2v/models/networks/wan/weights/transformer_weights.py            +109  −6
lightx2v/models/networks/wan/weights/vace/transformer_weights.py       +21   −4
lightx2v/models/runners/wan/wan_distill_runner.py                      +12   −6
requirements_animate.txt                                               +8    −0
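The common thread of the change set, as the commit title says, is dropping the device-to-host (d2h) copies from the CPU-offload inference path: instead of streaming each block onto the GPU and then moving it back, the weight classes now preallocate a small set of persistent CUDA buffers (offload_block_buffers / offload_phase_buffers) that an offload manager fills with host-to-device prefetches only. A minimal, self-contained sketch of the difference follows; the function and variable names here are illustrative, not the repository's API.

import torch

def offload_matmul_with_d2h(cpu_weights, x):
    """Old-style CPU offload: ship each weight to the GPU, use it, then move it
    back to host to free GPU memory (a d2h copy on every block)."""
    for w_cpu in cpu_weights:
        w = w_cpu.cuda(non_blocking=True)   # host -> device
        x = x @ w
        w = w.cpu()                         # device -> host transfer this commit removes
    return x

def offload_matmul_double_buffered(cpu_weights, x):
    """New-style CPU offload: two persistent CUDA buffers, host -> device prefetch only."""
    bufs = [torch.empty_like(cpu_weights[0], device="cuda") for _ in range(2)]
    bufs[0].copy_(cpu_weights[0], non_blocking=True)
    for i, _ in enumerate(cpu_weights):
        if i + 1 < len(cpu_weights):
            bufs[(i + 1) % 2].copy_(cpu_weights[i + 1], non_blocking=True)  # prefetch next weight
        x = x @ bufs[i % 2]                 # compute with the resident buffer; nothing is copied back
    return x

# Example usage (pinned host weights make the async copies effective):
# x = torch.randn(4, 1024, device="cuda")
# weights = [torch.randn(1024, 1024).pin_memory() for _ in range(8)]
# y = offload_matmul_double_buffered(weights, x)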
lightx2v/models/networks/qwen_image/weights/post_weights.py

@@ -10,11 +10,6 @@ class QwenImagePostWeights(WeightModule):
         super().__init__()
         self.task = config["task"]
         self.config = config
-        if config["do_mm_calib"]:
-            self.mm_type = "Calib"
-        else:
-            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
             assert NotImplementedError
@@ -23,7 +18,7 @@ class QwenImagePostWeights(WeightModule):
         # norm_out
         self.add_module(
             "norm_out_linear",
-            MM_WEIGHT_REGISTER[self.mm_type](
+            MM_WEIGHT_REGISTER["Default"](
                 "norm_out.linear.weight",
                 "norm_out.linear.bias",
                 self.lazy_load,
@@ -35,7 +30,7 @@ class QwenImagePostWeights(WeightModule):
         # proj_out
         self.add_module(
             "proj_out_linear",
-            MM_WEIGHT_REGISTER[self.mm_type](
+            MM_WEIGHT_REGISTER["Default"](
                 "proj_out.weight",
                 "proj_out.bias",
                 self.lazy_load,
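The post-projection weights now always resolve through the "Default" registry entry instead of the block-level mm_type, so they stay unquantized regardless of the DiT quantization scheme. A toy registry in the same spirit is sketched below; the real MM_WEIGHT_REGISTER lives in lightx2v.utils.registry_factory and its entries and signatures differ.

class DefaultMMWeight:
    """Plain, non-quantized linear weight holder (illustrative stand-in)."""
    def __init__(self, weight_name, bias_name=None, lazy_load=False, lazy_load_file=None):
        self.weight_name, self.bias_name = weight_name, bias_name
        self.lazy_load, self.lazy_load_file = lazy_load, lazy_load_file

MM_WEIGHT_REGISTER = {"Default": DefaultMMWeight}

# After this commit QwenImagePostWeights always builds the post weights through
# the "Default" entry, even when the transformer blocks use a quantized mm_type.
norm_out_linear = MM_WEIGHT_REGISTER["Default"]("norm_out.linear.weight", "norm_out.linear.bias")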
lightx2v/models/networks/qwen_image/weights/transformer_weights.py

@@ -12,17 +12,28 @@ class QwenImageTransformerWeights(WeightModule):
         self.blocks_num = config["num_layers"]
         self.task = config["task"]
         self.config = config
-        if config["do_mm_calib"]:
-            self.mm_type = "Calib"
-        else:
-            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
-        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "transformer_blocks") for i in range(self.blocks_num))
+        self.mm_type = config.get("dit_quant_scheme", "Default")
+        if self.mm_type != "Default":
+            assert config.get("dit_quantized") is True
+        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "transformer_blocks") for i in range(self.blocks_num))
+        self.register_offload_buffers(config)
         self.add_module("blocks", blocks)

+    def register_offload_buffers(self, config):
+        if config["cpu_offload"]:
+            if config["offload_granularity"] == "block":
+                self.offload_blocks_num = 2
+                self.offload_block_buffers = WeightModuleList(
+                    [QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, "transformer_blocks") for i in range(self.offload_blocks_num)]
+                )
+                self.add_module("offload_block_buffers", self.offload_block_buffers)
+                self.offload_phase_buffers = None
+            else:
+                raise NotImplementedError
+
 class QwenImageTransformerAttentionBlock(WeightModule):
-    def __init__(self, block_index, task, mm_type, config, block_prefix="transformer_blocks"):
+    def __init__(self, block_index, task, mm_type, config, is_offload_buffer=False, block_prefix="transformer_blocks"):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -44,22 +55,30 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.img_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.img_mod.1.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "img_norm1",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         self.attn = QwenImageCrossAttention(
-            block_index=block_index, block_prefix="transformer_blocks", task=config["task"], mm_type=mm_type, config=config, lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file
+            block_index=block_index,
+            block_prefix="transformer_blocks",
+            task=config["task"],
+            mm_type=mm_type,
+            config=config,
+            is_offload_buffer=is_offload_buffer,
+            lazy_load=self.lazy_load,
+            lazy_load_file=self.lazy_load_file,
         )
         self.add_module("attn", self.attn)
         self.add_module(
             "img_norm2",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         img_mlp = QwenImageFFN(
             block_index=block_index,
@@ -68,6 +87,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
+            is_offload_buffer=is_offload_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -79,19 +99,20 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.txt_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.txt_mod.1.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "txt_norm1",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         # Text doesn't need separate attention - it's handled by img_attn joint computation
         self.add_module(
             "txt_norm2",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         txt_mlp = QwenImageFFN(
             block_index=block_index,
@@ -100,39 +121,15 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
+            is_offload_buffer=is_offload_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
         self.add_module("txt_mlp", txt_mlp)
-        self.cpu_offload = config["cpu_offload"]
-        if self.cpu_offload:
-            offload_granularity = config.get("offload_granularity", "block")
-            if offload_granularity == "phase":
-                phase1_dict = {
-                    "img_mod": self.img_mod,
-                    "txt_mod": self.txt_mod,
-                    "img_norm1": self.img_norm1,
-                    "txt_norm1": self.txt_norm1,
-                }
-                phase2_dict = {"attn": self.attn}
-                phase3_dict = {
-                    "img_norm2": self.img_norm2,
-                    "img_mlp": self.img_mlp,
-                    "txt_norm2": self.txt_norm2,
-                    "txt_mlp": self.txt_mlp,
-                }
-                compute_phases = [
-                    ComputePhase(phase1_dict),
-                    ComputePhase(phase2_dict),
-                    ComputePhase(phase3_dict),
-                ]
-                self.add_module("compute_phases", compute_phases)

 class QwenImageCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -149,12 +146,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_q
         self.add_module(
             "norm_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=is_offload_buffer),
         )
         # norm_k
         self.add_module(
             "norm_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=is_offload_buffer),
         )
         # to_q
         self.add_module(
@@ -162,6 +159,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.to_q.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -172,6 +170,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.to_k.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -182,6 +181,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.to_v.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -192,6 +192,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.add_q_proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -202,6 +203,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.add_k_proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -212,6 +214,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.add_v_proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -222,6 +225,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.to_out.0.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -232,6 +236,7 @@ class QwenImageCrossAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.attn.to_add_out.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -239,12 +244,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_added_q
         self.add_module(
             "norm_added_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=is_offload_buffer),
         )
         # norm_added_k
         self.add_module(
             "norm_added_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=is_offload_buffer),
         )
         # attn
         self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]())
@@ -261,7 +266,7 @@ class QwenImageCrossAttention(WeightModule):
 class QwenImageFFN(WeightModule):
-    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -276,6 +281,7 @@ class QwenImageFFN(WeightModule):
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -285,6 +291,7 @@ class QwenImageFFN(WeightModule):
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -299,20 +306,3 @@ class QwenImageFFN(WeightModule):
         for module in self._modules.values():
             if module is not None and hasattr(module, "to_cuda"):
                 module.to_cuda(non_blocking=non_blocking)
-
-class ComputePhase(WeightModule):
-    def __init__(self, sub_module_dict):
-        super().__init__()
-        for k, v in sub_module_dict.items():
-            self.add_module(k, v)
-
-    def to_cpu(self, non_blocking=True):
-        for module in self._modules.values():
-            if module is not None and hasattr(module, "to_cpu"):
-                module.to_cpu(non_blocking=non_blocking)
-
-    def to_cuda(self, non_blocking=True):
-        for module in self._modules.values():
-            if module is not None and hasattr(module, "to_cuda"):
-                module.to_cuda(non_blocking=non_blocking)
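The new register_offload_buffers builds two extra "buffer" blocks whose weight wrappers are created with is_offload_buffer=True (or create_cuda_buffer=is_offload_buffer for the norms), i.e. their storage is allocated once on the GPU and real block weights are later copied into them. A sketch of that idea, with illustrative names only:

import torch

class BufferedLinear:
    """Sketch of a weight wrapper that can preallocate a persistent CUDA buffer
    (the create_cuda_buffer / is_offload_buffer idea); names are assumptions."""
    def __init__(self, out_features, in_features, create_cuda_buffer=False):
        self.weight = torch.empty(out_features, in_features, device="cuda") if create_cuda_buffer else None

    def load_from(self, cpu_weight, non_blocking=True):
        # host -> device copy into the preallocated buffer: no allocation, no d2h
        self.weight.copy_(cpu_weight, non_blocking=non_blocking)

Two such buffer blocks are enough for a double-buffered pipeline: while buffer 0 is used for compute, buffer 1 receives the next block's weights on a separate copy stream, and then the two are swapped.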
lightx2v/models/networks/wan/infer/animate/transformer_infer.py

@@ -10,9 +10,45 @@ class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
         self.has_post_adapter = True
         self.phases_num = 4

+    def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
+            self.block_idx = block_idx
+            if block_idx == 0:
+                self.offload_manager.init_first_buffer(blocks, block_idx // 5)
+            if block_idx < len(blocks) - 1:
+                self.offload_manager.prefetch_weights(block_idx + 1, blocks, (block_idx + 1) // 5)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
+            self.offload_manager.swap_blocks()
+        return x
+
+    def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
+        for phase_idx in range(self.phases_num):
+            if block_idx == 0 and phase_idx == 0:
+                if lazy:
+                    obj_key = (block_idx, phase_idx)
+                    phase = self.offload_manager.pin_memory_buffer.get(obj_key)
+                    phase.to_cuda()
+                    self.offload_manager.cuda_buffers[0] = (obj_key, phase)
+                else:
+                    self.offload_manager.init_first_buffer(blocks, block_idx // 5)
+            is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
+            if not is_last_phase:
+                next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
+                next_phase_idx = (phase_idx + 1) % self.phases_num
+                self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks, (block_idx + 1) // 5)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
+            self.offload_manager.swap_phases()
+        return x
+
     @torch.no_grad()
     def infer_post_adapter(self, phase, x, pre_infer_out):
-        if phase.is_empty():
+        if phase.is_empty() or phase.linear1_kv.weight is None:
             return x
         T = pre_infer_out.adapter_args["motion_vec"].shape[0]
         x_motion = phase.pre_norm_motion.apply(pre_infer_out.adapter_args["motion_vec"])
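The animate variant adds a fourth compute phase for the face-adapter fuser and passes an adapter index of block_idx // 5 to the offload manager, which matches the weights file in this commit appending a WanAnimateFuserBlock(self.config, i // 5, ...) to (apparently) every fifth transformer block; the condition itself is not visible in the displayed hunk, so the stride is an inference. A tiny illustration of that indexing:

# Sketch of the adapter indexing used above: with one fuser block per group of
# five transformer blocks, block_idx // 5 selects which fuser weights to
# prefetch alongside the block (stride assumed from the i // 5 in the weights file).
def adapter_index(block_idx, stride=5):
    return block_idx // stride

assert [adapter_index(i) for i in range(11)] == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2]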
lightx2v/models/networks/wan/infer/offload/transformer_infer.py

@@ -42,13 +42,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         if offload_granularity != "model":
             if not self.config.get("lazy_load", False):
-                self.weights_stream_mgr = WeightAsyncStreamManager(
-                    blocks_num=self.blocks_num,
-                    offload_ratio=self.offload_ratio,
-                    phases_num=self.phases_num,
-                )
+                self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
             else:
-                self.weights_stream_mgr = LazyWeightAsyncStreamManager(
+                self.offload_manager = LazyWeightAsyncStreamManager(
                     blocks_num=self.blocks_num,
                     offload_ratio=self.offload_ratio,
                     phases_num=self.phases_num,
@@ -61,40 +57,57 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = blocks[0]
-                self.weights_stream_mgr.active_weights[0].to_cuda()
+                self.offload_manager.init_first_buffer(blocks)
             if block_idx < len(blocks) - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
-            self.weights_stream_mgr.swap_weights()
+                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
+            self.offload_manager.swap_blocks()
         return x

+    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
+            self.block_idx = block_idx
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
+        if self.clean_cuda_cache:
+            del (
+                self.phase_params["attn_out"],
+                self.phase_params["y_out"],
+                self.phase_params["y"],
+            )
+            torch.cuda.empty_cache()
+        if self.clean_cuda_cache:
+            self.clear_offload_params(pre_infer_out)
+        return x
+
     def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        self.offload_manager.prefetch_weights_from_disk(blocks)
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             if block_idx == 0:
-                block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
+                block = self.offload_manager.pin_memory_buffer.get(block_idx)
                 block.to_cuda()
-                self.weights_stream_mgr.active_weights[0] = (block_idx, block)
+                self.offload_manager.cuda_buffers[0] = (block_idx, block)
             if block_idx < len(blocks) - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
                 x = self.infer_block(blocks[block_idx], x, pre_infer_out)
-            self.weights_stream_mgr.swap_weights()
+            self.offload_manager.swap_blocks()
             if block_idx == len(blocks) - 1:
-                self.weights_stream_mgr.pin_memory_buffer.pop_front()
-            self.weights_stream_mgr._async_prefetch_block(blocks)
+                self.offload_manager.pin_memory_buffer.pop_front()
+            self.offload_manager._async_prefetch_block(blocks)
             if self.clean_cuda_cache:
                 del (
@@ -106,31 +119,14 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         return x

-    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
-        if self.clean_cuda_cache:
-            del (
-                self.phase_params["attn_out"],
-                self.phase_params["y_out"],
-                self.phase_params["y"],
-            )
-            torch.cuda.empty_cache()
-        if self.clean_cuda_cache:
-            self.clear_offload_params(pre_infer_out)
-        return x
-
     def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        self.offload_manager.prefetch_weights_from_disk(blocks)
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
-            self.weights_stream_mgr._async_prefetch_block(blocks)
+            self.offload_manager._async_prefetch_block(blocks)
             if self.clean_cuda_cache:
                 del (
@@ -148,35 +144,27 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
             if block_idx == 0 and phase_idx == 0:
                 if lazy:
                     obj_key = (block_idx, phase_idx)
-                    phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
+                    phase = self.offload_manager.pin_memory_buffer.get(obj_key)
                     phase.to_cuda()
-                    self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
+                    self.offload_manager.cuda_buffers[0] = (obj_key, phase)
                 else:
-                    phase = blocks[block_idx].compute_phases[phase_idx]
-                    phase.to_cuda()
-                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_phase(self.weights_stream_mgr.active_weights[0], x, pre_infer_out)
+                    self.offload_manager.init_first_buffer(blocks)
             is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
             if not is_last_phase:
                 next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
                 next_phase_idx = (phase_idx + 1) % self.phases_num
-                self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, blocks)
-            self.weights_stream_mgr.swap_phases()
+                self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
+            self.offload_manager.swap_phases()
         return x

-    def infer_phase(self, active_weight, x, pre_infer_out):
-        if not self.config.get("lazy_load"):
-            cur_phase_idx, cur_phase = active_weight
-        else:
-            (_, cur_phase_idx), cur_phase = active_weight
+    def infer_phase(self, cur_phase_idx, cur_phase, x, pre_infer_out):
         if cur_phase_idx == 0:
-            if hasattr(cur_phase, "before_proj"):
+            if hasattr(cur_phase, "before_proj") and cur_phase.before_proj.weight is not None:
                 x = cur_phase.before_proj.apply(x) + pre_infer_out.x
             (
                 self.phase_params["shift_msa"],
@@ -211,11 +199,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
                 self.phase_params["c_shift_msa"],
                 self.phase_params["c_scale_msa"],
             )
-            x = self.post_process(
-                x,
-                self.phase_params["y"],
-                self.phase_params["c_gate_msa"],
-            )
+            x = self.post_process(x, self.phase_params["y"], self.phase_params["c_gate_msa"], pre_infer_out)
             if hasattr(cur_phase, "after_proj"):
                 pre_infer_out.adapter_args["hints"].append(cur_phase.after_proj.apply(x))
         elif cur_phase_idx == 3:
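The block-level loop above is the core of the change: the manager keeps two GPU-resident buffers, the next block's weights are prefetched into the idle one while the current one computes, and swap_blocks() rotates them, so the steady-state loop contains only host-to-device traffic. A self-contained sketch of that control flow, with a toy stand-in for the manager (the class below is illustrative, not the repository's WeightAsyncStreamManager):

import torch

class TinyOffloadManager:
    """Toy stand-in for the offload manager; method names mirror the diff,
    the implementation here is only a sketch."""

    def __init__(self, example_weight):
        self.copy_stream = torch.cuda.Stream()
        self.compute_stream = torch.cuda.Stream()
        # Two persistent GPU buffers: one computes while the other is filled.
        self.cuda_buffers = [torch.empty_like(example_weight, device="cuda") for _ in range(2)]

    def init_first_buffer(self, weights):
        self.cuda_buffers[0].copy_(weights[0], non_blocking=True)

    def prefetch_weights(self, idx, weights):
        with torch.cuda.stream(self.copy_stream):
            self.cuda_buffers[1].copy_(weights[idx], non_blocking=True)

    def swap_blocks(self):
        # Make sure the prefetch finished before the next iteration computes on it.
        self.compute_stream.wait_stream(self.copy_stream)
        self.cuda_buffers[0], self.cuda_buffers[1] = self.cuda_buffers[1], self.cuda_buffers[0]

def infer_with_blocks_offload_sketch(weights, x, manager):
    # Same control flow as the new infer_with_blocks_offload: prefetch block i+1
    # into the idle buffer while block i computes, then swap; nothing is copied
    # back to the host inside the loop.
    for i in range(len(weights)):
        if i == 0:
            manager.init_first_buffer(weights)
        if i < len(weights) - 1:
            manager.prefetch_weights(i + 1, weights)
        with torch.cuda.stream(manager.compute_stream):
            x = x @ manager.cuda_buffers[0]
        manager.swap_blocks()
    return x

A production version additionally needs the copy stream to wait for compute before overwriting a buffer; that synchronization is omitted here for brevity.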
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -88,7 +88,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

     def infer_block(self, block, x, pre_infer_out):
-        if hasattr(block.compute_phases[0], "before_proj"):
+        if hasattr(block.compute_phases[0], "before_proj") and block.compute_phases[0].before_proj.weight is not None:
             x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
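Because the shared offload buffers now always carry a before_proj / adapter slot, even for blocks that never use it, the mere presence of the attribute no longer implies real weights; the guards in this file and in the offload and animate infer paths therefore also test that the weight tensor was populated. A minimal illustration (attribute names follow the diff, the classes are assumed):

class Proj:
    def __init__(self, weight=None):
        self.weight = weight
    def apply(self, x):
        return x + 1  # placeholder for the real projection

class PhaseWithEmptySlot:
    before_proj = Proj(weight=None)       # slot exists on every offload buffer ...

class PhaseWithWeights:
    before_proj = Proj(weight=object())   # ... but only some blocks fill it

def maybe_before_proj(phase, x):
    # old check: hasattr(phase, "before_proj") -> would wrongly fire on the empty slot
    if hasattr(phase, "before_proj") and phase.before_proj.weight is not None:
        x = phase.before_proj.apply(x)
    return x

assert maybe_before_proj(PhaseWithEmptySlot(), 0) == 0
assert maybe_before_proj(PhaseWithWeights(), 0) == 1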
lightx2v/models/networks/wan/infer/vace/transformer_infer.py

@@ -22,12 +22,12 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
     def infer_vace_blocks(self, vace_blocks, pre_infer_out):
         pre_infer_out.adapter_args["hints"] = []
         self.infer_state = "vace"
-        if hasattr(self, "weights_stream_mgr"):
-            self.weights_stream_mgr.init(self.vace_blocks_num, self.phases_num, self.offload_ratio)
+        if hasattr(self, "offload_manager"):
+            self.offload_manager.init_cuda_buffer(self.vace_offload_block_buffers, self.vace_offload_phase_buffers)
         self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
         self.infer_state = "base"
-        if hasattr(self, "weights_stream_mgr"):
-            self.weights_stream_mgr.init(self.blocks_num, self.phases_num, self.offload_ratio)
+        if hasattr(self, "offload_manager"):
+            self.offload_manager.init_cuda_buffer(self.offload_block_buffers, self.offload_phase_buffers)

     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
lightx2v/models/networks/wan/model.py

@@ -365,6 +365,8 @@ class WanModel(CompiledMethodsMixin):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
+        if hasattr(self.transformer_infer, "offload_manager"):
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
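The model now hands the weight object's preallocated buffers to the infer object's offload manager right after both are constructed, which is the only wiring the new buffers need. A small stand-in showing that hand-off (the classes here are dummies; only the attribute and method names follow the diff):

class DummyOffloadManager:
    def init_cuda_buffer(self, block_buffers, phase_buffers):
        # adopt whichever buffer set the weights object created
        self.cuda_buffers = block_buffers if block_buffers is not None else phase_buffers

class DummyTransformerInfer:
    def __init__(self):
        self.offload_manager = DummyOffloadManager()

class DummyTransformerWeights:
    def __init__(self):
        self.offload_block_buffers = ["gpu_block_buf_0", "gpu_block_buf_1"]
        self.offload_phase_buffers = None

# Mirrors the two lines added to WanModel: once both objects exist, the infer
# side's offload manager adopts the weights side's preallocated buffers.
transformer_weights = DummyTransformerWeights()
transformer_infer = DummyTransformerInfer()
if hasattr(transformer_infer, "offload_manager"):
    transformer_infer.offload_manager.init_cuda_buffer(
        transformer_weights.offload_block_buffers,
        transformer_weights.offload_phase_buffers,
    )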
lightx2v/models/networks/wan/vace_model.py

@@ -19,6 +19,14 @@ class WanVaceModel(WanModel):
     def __init__(self, model_path, config, device):
         super().__init__(model_path, config, device)

+    def _init_infer(self):
+        super()._init_infer()
+        if hasattr(self.transformer_infer, "offload_manager"):
+            self.transformer_infer.offload_block_buffers = self.transformer_weights.offload_block_buffers
+            self.transformer_infer.offload_phase_buffers = self.transformer_weights.offload_phase_buffers
+            self.transformer_infer.vace_offload_block_buffers = self.transformer_weights.vace_offload_block_buffers
+            self.transformer_infer.vace_offload_phase_buffers = self.transformer_weights.vace_offload_phase_buffers
+
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
lightx2v/models/networks/wan/weights/animate/transformer_weights.py

@@ -6,7 +6,12 @@ from lightx2v.common.modules.weight_module import WeightModule
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
-from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import (
+    ATTN_WEIGHT_REGISTER,
+    LN_WEIGHT_REGISTER,
+    MM_WEIGHT_REGISTER,
+    RMS_WEIGHT_REGISTER,
+)

 class WanAnimateTransformerWeights(WanTransformerWeights):
@@ -18,39 +23,74 @@ class WanAnimateTransformerWeights(WanTransformerWeights):
                 self.blocks[i].compute_phases.append(WanAnimateFuserBlock(self.config, i // 5, "face_adapter.fuser_blocks", self.mm_type))
             else:
                 self.blocks[i].compute_phases.append(WeightModule())
+        self._add_animate_fuserblock_to_offload_buffers()
+
+    def _add_animate_fuserblock_to_offload_buffers(self):
+        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+            for i in range(self.offload_blocks_num):
+                self.offload_block_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
+        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
+            self.offload_phase_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))

 class WanAnimateFuserBlock(WeightModule):
-    def __init__(self, config, block_index, block_prefix, mm_type):
+    def __init__(self, config, block_index, block_prefix, mm_type, is_offload_buffer=False):
         super().__init__()
         self.config = config
+        self.is_post_adapter = True
         lazy_load = config.get("lazy_load", False)
         if lazy_load:
-            lazy_load_path = os.path.join(config.dit_quantized_ckpt, f"{block_prefix[:-1]}_{block_index}.safetensors")
+            lazy_load_path = os.path.join(
+                config.dit_quantized_ckpt,
+                f"{block_prefix[:-1]}_{block_index}.safetensors",
+            )
             lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
         else:
             lazy_load_file = None
         self.add_module(
             "linear1_kv",
-            MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear1_kv.weight", f"{block_prefix}.{block_index}.linear1_kv.bias", lazy_load, lazy_load_file),
+            MM_WEIGHT_REGISTER[mm_type](
+                f"{block_prefix}.{block_index}.linear1_kv.weight",
+                f"{block_prefix}.{block_index}.linear1_kv.bias",
+                is_offload_buffer,
+                lazy_load,
+                lazy_load_file,
+                self.is_post_adapter,
+            ),
         )
         self.add_module(
             "linear1_q",
-            MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear1_q.weight", f"{block_prefix}.{block_index}.linear1_q.bias", lazy_load, lazy_load_file),
+            MM_WEIGHT_REGISTER[mm_type](
+                f"{block_prefix}.{block_index}.linear1_q.weight",
+                f"{block_prefix}.{block_index}.linear1_q.bias",
+                is_offload_buffer,
+                lazy_load,
+                lazy_load_file,
+                self.is_post_adapter,
+            ),
         )
         self.add_module(
             "linear2",
-            MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear2.weight", f"{block_prefix}.{block_index}.linear2.bias", lazy_load, lazy_load_file),
+            MM_WEIGHT_REGISTER[mm_type](
+                f"{block_prefix}.{block_index}.linear2.weight",
+                f"{block_prefix}.{block_index}.linear2.bias",
+                is_offload_buffer,
+                lazy_load,
+                lazy_load_file,
+                self.is_post_adapter,
+            ),
         )
         self.add_module(
             "q_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.q_norm.weight",
+                is_offload_buffer,
                 lazy_load,
                 lazy_load_file,
+                self.is_post_adapter,
             ),
         )
@@ -58,8 +98,10 @@ class WanAnimateFuserBlock(WeightModule):
             "k_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.k_norm.weight",
+                is_offload_buffer,
                 lazy_load,
                 lazy_load_file,
+                self.is_post_adapter,
             ),
         )
@@ -67,6 +109,7 @@ class WanAnimateFuserBlock(WeightModule):
             "pre_norm_feat",
             LN_WEIGHT_REGISTER["Default"](),
         )
         self.add_module(
             "pre_norm_motion",
             LN_WEIGHT_REGISTER["Default"](),
lightx2v/models/networks/wan/weights/audio/transformer_weights.py

@@ -18,14 +18,46 @@ class WanAudioTransformerWeights(WanTransformerWeights):
                     self.task,
                     self.mm_type,
                     self.config,
+                    False,
                     self.blocks[i].lazy_load,
                     self.blocks[i].lazy_load_file,
                 )
             )
+        self._add_audio_adapter_ca_to_offload_buffers()
+
+    def _add_audio_adapter_ca_to_offload_buffers(self):
+        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+            for i in range(self.offload_blocks_num):
+                offload_buffer = self.offload_block_buffers[i]
+                adapter_ca = WanAudioAdapterCA(
+                    block_index=i,
+                    block_prefix=f"ca",
+                    task=self.task,
+                    mm_type=self.mm_type,
+                    config=self.config,
+                    is_offload_buffer=True,
+                    lazy_load=offload_buffer.lazy_load,
+                    lazy_load_file=offload_buffer.lazy_load_file,
+                )
+                offload_buffer.compute_phases.append(adapter_ca)
+        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
+            adapter_ca = WanAudioAdapterCA(
+                block_index=0,
+                block_prefix=f"ca",
+                task=self.task,
+                mm_type=self.mm_type,
+                config=self.config,
+                is_offload_buffer=True,
+                lazy_load=self.blocks[0].lazy_load,
+                lazy_load_file=self.blocks[0].lazy_load_file,
+            )
+            self.offload_phase_buffers.append(adapter_ca)

 class WanAudioAdapterCA(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -39,6 +71,7 @@ class WanAudioAdapterCA(WeightModule):
                 f"{block_prefix}.{block_index}.to_q.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -49,6 +82,7 @@ class WanAudioAdapterCA(WeightModule):
                 f"{block_prefix}.{block_index}.to_kv.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -59,6 +93,7 @@ class WanAudioAdapterCA(WeightModule):
                 f"{block_prefix}.{block_index}.to_out.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -69,6 +104,7 @@ class WanAudioAdapterCA(WeightModule):
                 f"{block_prefix}.{block_index}.norm_kv.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -83,6 +119,7 @@ class WanAudioAdapterCA(WeightModule):
                 f"{block_prefix}.{block_index}.shift_scale_gate",
+                is_offload_buffer,
                 self.lazy_load,
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -24,6 +24,7 @@ class WanTransformerWeights(WeightModule):
         if config.get("do_mm_calib", False):
             self.mm_type = "Calib"
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
+        self.register_offload_buffers(config)
         self.add_module("blocks", self.blocks)

         # non blocks weights
@@ -31,6 +32,35 @@ class WanTransformerWeights(WeightModule):
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

+    def register_offload_buffers(self, config):
+        if config["cpu_offload"]:
+            if config["offload_granularity"] == "block":
+                self.offload_blocks_num = 2
+                self.offload_block_buffers = WeightModuleList(
+                    [
+                        WanTransformerAttentionBlock(
+                            i,
+                            self.task,
+                            self.mm_type,
+                            self.config,
+                            True,
+                        )
+                        for i in range(self.offload_blocks_num)
+                    ]
+                )
+                self.add_module("offload_block_buffers", self.offload_block_buffers)
+                self.offload_phase_buffers = None
+            elif config["offload_granularity"] == "phase":
+                self.offload_phase_buffers = WanTransformerAttentionBlock(
+                    0,
+                    self.task,
+                    self.mm_type,
+                    self.config,
+                    True,
+                ).compute_phases
+                self.add_module("offload_phase_buffers", self.offload_phase_buffers)
+                self.offload_block_buffers = None
+
     def clear(self):
         for block in self.blocks:
             for phase in block.compute_phases:
@@ -48,12 +78,21 @@ class WanTransformerWeights(WeightModule):
 class WanTransformerAttentionBlock(WeightModule):
-    def __init__(self, block_index, task, mm_type, config, block_prefix="blocks"):
+    def __init__(
+        self,
+        block_index,
+        task,
+        mm_type,
+        config,
+        is_offload_buffer=False,
+        block_prefix="blocks",
+    ):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
+        self.is_offload_buffer = is_offload_buffer
         self.quant_method = config.get("quant_method", None)
         self.lazy_load = self.config.get("lazy_load", False)
@@ -71,6 +110,7 @@ class WanTransformerAttentionBlock(WeightModule):
             config,
+            is_offload_buffer,
             self.lazy_load,
@@ -80,6 +120,7 @@ class WanTransformerAttentionBlock(WeightModule):
             config,
+            is_offload_buffer,
             self.lazy_load,
@@ -89,6 +130,7 @@ class WanTransformerAttentionBlock(WeightModule):
             config,
+            is_offload_buffer,
             self.lazy_load,
@@ -99,7 +141,17 @@ class WanTransformerAttentionBlock(WeightModule):
 class WanSelfAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(
+        self,
+        block_index,
+        block_prefix,
+        task,
+        mm_type,
+        config,
+        is_offload_buffer,
+        lazy_load,
+        lazy_load_file,
+    ):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -119,6 +171,7 @@ class WanSelfAttention(WeightModule):
             "modulation",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.modulation",
+                is_offload_buffer,
                 self.lazy_load,
@@ -134,6 +187,7 @@ class WanSelfAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.self_attn.q.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -144,6 +198,7 @@ class WanSelfAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.self_attn.k.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -153,6 +208,7 @@ class WanSelfAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.self_attn.v.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -162,6 +218,7 @@ class WanSelfAttention(WeightModule):
                 f"{block_prefix}.{self.block_index}.self_attn.o.bias",
+                is_offload_buffer,
                 self.lazy_load,
@@ -170,6 +227,7 @@ class WanSelfAttention(WeightModule):
             "self_attn_norm_q",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight",
+                is_offload_buffer,
                 self.lazy_load,
@@ -178,6 +236,7 @@ class WanSelfAttention(WeightModule):
             "self_attn_norm_k",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight",
+                is_offload_buffer,
                 self.lazy_load,
@@ -192,7 +251,12 @@ class WanSelfAttention(WeightModule):
                 context_length=self.config.get("svg_context_length", 0),
                 sparsity=self.config.get("svg_sparsity", 0.25),
             )
-        if self.config["self_attn_1_type"] in ["svg_attn", "radial_attn", "nbhd_attn", "nbhd_attn_flashinfer"]:
+        if self.config["self_attn_1_type"] in [
+            "svg_attn",
+            "radial_attn",
+            "nbhd_attn",
+            "nbhd_attn_flashinfer",
+        ]:
             attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
         # nbhd_attn setting
         if self.config["self_attn_1_type"] in ["nbhd_attn", "nbhd_attn_flashinfer"]:
@@ -204,13 +268,17 @@ class WanSelfAttention(WeightModule):
         self.add_module("self_attn_1", attention_weights_cls())
         if self.config["seq_parallel"]:
-            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())
+            self.add_module(
+                "self_attn_1_parallel",
+                ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")](),
+            )
         if self.quant_method in ["advanced_ptq"]:
             self.add_module(
                 "smooth_norm1_weight",
                 TENSOR_REGISTER["Default"](
                     f"{block_prefix}.{self.block_index}.affine_norm1.weight",
+                    is_offload_buffer,
                     self.lazy_load,
                     self.lazy_load_file,
                 ),
@@ -219,6 +287,7 @@ class WanSelfAttention(WeightModule):
                 "smooth_norm1_bias",
                 TENSOR_REGISTER["Default"](
                     f"{block_prefix}.{self.block_index}.affine_norm1.bias",
+                    is_offload_buffer,
                     self.lazy_load,
@@ -226,7 +295,17 @@ class WanSelfAttention(WeightModule):
 class WanCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(
+        self,
+        block_index,
+        block_prefix,
+        task,
+        mm_type,
+        config,
+        is_offload_buffer,
+        lazy_load,
+        lazy_load_file,
+    ):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -245,6 +324,7 @@ class WanCrossAttention(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.norm3.weight",
                 f"{block_prefix}.{self.block_index}.norm3.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -254,6 +334,7 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}
.cross_attn.q.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.q.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.q.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.q.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -263,6 +344,7 @@ class WanCrossAttention(WeightModule):
...
@@ -263,6 +344,7 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -272,6 +354,7 @@ class WanCrossAttention(WeightModule):
...
@@ -272,6 +354,7 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -281,6 +364,7 @@ class WanCrossAttention(WeightModule):
...
@@ -281,6 +364,7 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.o.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.o.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.o.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.o.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -289,6 +373,7 @@ class WanCrossAttention(WeightModule):
...
@@ -289,6 +373,7 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_q"
,
"cross_attn_norm_q"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_q.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_q.weight"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -297,6 +382,7 @@ class WanCrossAttention(WeightModule):
...
@@ -297,6 +382,7 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k"
,
"cross_attn_norm_k"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_k.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_k.weight"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -309,6 +395,7 @@ class WanCrossAttention(WeightModule):
...
@@ -309,6 +395,7 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k_img.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k_img.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k_img.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.k_img.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -318,6 +405,7 @@ class WanCrossAttention(WeightModule):
...
@@ -318,6 +405,7 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v_img.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v_img.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v_img.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.v_img.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -326,6 +414,7 @@ class WanCrossAttention(WeightModule):
...
@@ -326,6 +414,7 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k_img"
,
"cross_attn_norm_k_img"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_k_img.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_k_img.weight"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -334,7 +423,17 @@ class WanCrossAttention(WeightModule):
...
@@ -334,7 +423,17 @@ class WanCrossAttention(WeightModule):
class
WanFFN
(
WeightModule
):
class
WanFFN
(
WeightModule
):
def
__init__
(
self
,
block_index
,
block_prefix
,
task
,
mm_type
,
config
,
lazy_load
,
lazy_load_file
):
def
__init__
(
self
,
block_index
,
block_prefix
,
task
,
mm_type
,
config
,
is_offload_buffer
,
lazy_load
,
lazy_load_file
,
):
super
().
__init__
()
super
().
__init__
()
self
.
block_index
=
block_index
self
.
block_index
=
block_index
self
.
mm_type
=
mm_type
self
.
mm_type
=
mm_type
...
@@ -354,6 +453,7 @@ class WanFFN(WeightModule):
...
@@ -354,6 +453,7 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.0.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.0.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.0.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.0.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -363,6 +463,7 @@ class WanFFN(WeightModule):
...
@@ -363,6 +463,7 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
MM_WEIGHT_REGISTER
[
self
.
mm_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.2.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.2.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.2.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.ffn.2.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -373,6 +474,7 @@ class WanFFN(WeightModule):
...
@@ -373,6 +474,7 @@ class WanFFN(WeightModule):
"smooth_norm2_weight"
,
"smooth_norm2_weight"
,
TENSOR_REGISTER
[
"Default"
](
TENSOR_REGISTER
[
"Default"
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm3.weight"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm3.weight"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
@@ -381,6 +483,7 @@ class WanFFN(WeightModule):
...
@@ -381,6 +483,7 @@ class WanFFN(WeightModule):
"smooth_norm2_bias"
,
"smooth_norm2_bias"
,
TENSOR_REGISTER
[
"Default"
](
TENSOR_REGISTER
[
"Default"
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm3.bias"
,
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm3.bias"
,
is_offload_buffer
,
self
.
lazy_load
,
self
.
lazy_load
,
self
.
lazy_load_file
,
self
.
lazy_load_file
,
),
),
...
...
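Across these hunks the change is mechanical: every weight wrapper (MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, TENSOR_REGISTER) now receives an extra positional is_offload_buffer argument ahead of the lazy-load arguments. A minimal sketch of what such a flag typically controls is below; the class name, fill_from helper, and fields other than is_offload_buffer are illustrative assumptions, not LightX2V's actual API. The idea matches the commit title: a module that acts as a resident offload staging buffer keeps a persistent GPU tensor that is only ever refilled with host-to-device copies, so inference never needs the device-to-host copy this commit removes.

import torch

class OffloadAwareWeight:
    def __init__(self, weight_name, bias_name, is_offload_buffer=False,
                 lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.is_offload_buffer = is_offload_buffer   # flag threaded through this diff
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.weight = None                           # persistent GPU tensor when used as a buffer

    def fill_from(self, cpu_weight: torch.Tensor) -> None:
        if self.is_offload_buffer:
            # Staging buffer: allocate once on the GPU, then refill in place with
            # async H2D copies. Nothing is ever copied back to the CPU (no d2h).
            if self.weight is None:
                self.weight = torch.empty_like(cpu_weight, device="cuda")
            self.weight.copy_(cpu_weight, non_blocking=True)
        else:
            # Regular weight: move the freshly loaded tensor to the GPU as before.
            self.weight = cpu_weight.to("cuda", non_blocking=True)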
lightx2v/models/networks/wan/weights/vace/transformer_weights.py
View file @ 51be3ad2
@@ -13,16 +13,31 @@ class WanVaceTransformerWeights(WanTransformerWeights):
    def __init__(self, config):
        super().__init__(config)
        self.patch_size = (1, 2, 2)
+       self.register_offload_buffers(config)
        self.vace_blocks = WeightModuleList(
-           [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
+           [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
        )
        self.add_module("vace_blocks", self.vace_blocks)
        self.add_module(
            "vace_patch_embedding",
            CONV3D_WEIGHT_REGISTER["Default"]("vace_patch_embedding.weight", "vace_patch_embedding.bias", stride=self.patch_size),
        )

+   def register_offload_buffers(self, config):
+       super().register_offload_buffers(config)
+       if config["cpu_offload"]:
+           if config["offload_granularity"] == "block":
+               self.vace_offload_block_buffers = WeightModuleList(
+                   [
+                       WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
+                       WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
+                   ]
+               )
+               self.add_module("vace_offload_block_buffers", self.vace_offload_block_buffers)
+               self.vace_offload_phase_buffers = None
+           elif config["offload_granularity"] == "phase":
+               raise NotImplementedError
+
    def clear(self):
        super().clear()
        for vace_block in self.vace_blocks:
@@ -39,14 +54,15 @@ class WanVaceTransformerWeights(WanTransformerWeights):
class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
-   def __init__(self, base_block_idx, block_index, task, mm_type, config, block_prefix):
-       super().__init__(block_index, task, mm_type, config, block_prefix)
+   def __init__(self, base_block_idx, block_index, task, mm_type, config, is_offload_buffer, block_prefix):
+       super().__init__(block_index, task, mm_type, config, is_offload_buffer, block_prefix)
        if base_block_idx == 0:
            self.compute_phases[0].add_module(
                "before_proj",
                MM_WEIGHT_REGISTER[self.mm_type](
                    f"{block_prefix}.{self.block_index}.before_proj.weight",
                    f"{block_prefix}.{self.block_index}.before_proj.bias",
+                   is_offload_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
@@ -57,6 +73,7 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
                MM_WEIGHT_REGISTER[self.mm_type](
                    f"{block_prefix}.{self.block_index}.after_proj.weight",
                    f"{block_prefix}.{self.block_index}.after_proj.bias",
+                   is_offload_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
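When offload_granularity is "block", register_offload_buffers allocates exactly two reusable VACE block buffers (built with is_offload_buffer=True). Two identical resident buffers are the usual setup for ping-pong prefetching; the sketch below shows that scheduling under assumed semantics, where load_from and forward_block are hypothetical helpers rather than LightX2V functions. With both buffers kept on the GPU, weights only ever move host-to-device and activations stay on the device, which is what removing the d2h copy amounts to.

import torch

def run_offloaded_blocks(blocks_cpu, buffers, forward_block, x):
    # Ping-pong between two resident GPU block buffers: while buffer A runs
    # block i, buffer B is refilled (async H2D) with block i+1's weights.
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()
    buffers[0].load_from(blocks_cpu[0])                # assumed helper: H2D fill
    for i in range(len(blocks_cpu)):
        cur, nxt = buffers[i % 2], buffers[(i + 1) % 2]
        if i + 1 < len(blocks_cpu):
            copy_stream.wait_stream(compute_stream)    # nxt is free to overwrite
            with torch.cuda.stream(copy_stream):
                nxt.load_from(blocks_cpu[i + 1])       # prefetch next block
        x = forward_block(cur, x)                      # compute current block
        compute_stream.wait_stream(copy_stream)        # prefetch must finish before swap
    return x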
lightx2v/models/runners/wan/wan_distill_runner.py
View file @ 51be3ad2
@@ -78,21 +78,27 @@ class MultiDistillModelStruct(MultiModelStruct):
class Wan22MoeDistillRunner(WanDistillRunner):
    def __init__(self, config):
        super().__init__(config)
-       self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
-       if not os.path.isdir(self.high_noise_model_path):
-           self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
        if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
        elif self.config.get("high_noise_original_ckpt", None):
            self.high_noise_model_path = self.config["high_noise_original_ckpt"]
+       else:
+           self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
+           if not os.path.isdir(self.high_noise_model_path):
+               self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
+           if not os.path.isdir(self.high_noise_model_path):
+               raise FileNotFoundError(f"High Noise Model does not find")
-       self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
-       if not os.path.isdir(self.low_noise_model_path):
-           self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
        if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
        elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
            self.low_noise_model_path = self.config["low_noise_original_ckpt"]
+       else:
+           self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
+           if not os.path.isdir(self.low_noise_model_path):
+               self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
+           if not os.path.isdir(self.high_noise_model_path):
+               raise FileNotFoundError(f"Low Noise Model does not find")

    def load_transformer(self):
        use_high_lora, use_low_lora = False, False
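The runner hunk reorders checkpoint resolution: an explicitly configured quantized or original checkpoint now wins outright, and the directory-layout lookup only runs in the else branch, failing fast when neither layout exists. A condensed sketch of that order for the high-noise model is below (the low-noise path mirrors it); the function name is illustrative, not part of the runner.

import os

def resolve_high_noise_path(config: dict) -> str:
    # Order implemented by this hunk: explicit quantized ckpt -> explicit
    # original ckpt -> default layout under model_path, with a
    # distill_models/ fallback, else fail fast.
    if config.get("dit_quantized", False) and config.get("high_noise_quantized_ckpt"):
        return config["high_noise_quantized_ckpt"]
    if config.get("high_noise_original_ckpt"):
        return config["high_noise_original_ckpt"]
    path = os.path.join(config["model_path"], "high_noise_model")
    if not os.path.isdir(path):
        path = os.path.join(config["model_path"], "distill_models", "high_noise_model")
    if not os.path.isdir(path):
        raise FileNotFoundError("high noise model not found")
    return path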
requirements_animate.txt  0 → 100644
View file @ 51be3ad2
decord
peft
onnxruntime
pandas
matplotlib
-e git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c4bd9a3e95b4e266#egg=SAM-2
loguru
sentencepiece
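This new requirements file collects the extra dependencies for the animate workflow (SAM-2 is pinned to a specific upstream commit via the editable git URL) and is presumably installed on top of the base requirements with pip's -r flag.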