xuwx1 / LightX2V · Commits

Commit 74eeb429 (Unverified)
reconstruct disk offload and fix lightx2v_platform bugs (#558)
Authored Dec 03, 2025 by Gu Shiqiao; committed by GitHub on Dec 03, 2025
Co-authored-by: helloyongyang <yongyang1030@163.com>
Parent: f7cdbcb5

Showing 20 changed files with 391 additions and 380 deletions
lightx2v/models/networks/qwen_image/weights/post_weights.py                 +10   −0
lightx2v/models/networks/qwen_image/weights/pre_weights.py                  +10   −0
lightx2v/models/networks/qwen_image/weights/transformer_weights.py          +46   −31
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py     +8    −5
lightx2v/models/networks/wan/infer/offload/transformer_infer.py             +24   −105
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py                +2    −1
lightx2v/models/networks/wan/infer/vace/transformer_infer.py                +3    −2
lightx2v/models/networks/wan/model.py                                       +21   −30
lightx2v/models/networks/wan/vace_model.py                                  +9    −4
lightx2v/models/networks/wan/weights/animate/transformer_weights.py         +19   −10
lightx2v/models/networks/wan/weights/audio/transformer_weights.py           +47   −12
lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py    +6    −77
lightx2v/models/networks/wan/weights/transformer_weights.py                 +154  −68
lightx2v/models/networks/wan/weights/vace/transformer_weights.py            +12   −16
lightx2v/models/runners/default_runner.py                                   +3    −6
lightx2v/models/runners/wan/wan_animate_runner.py                           +6    −5
lightx2v/models/runners/wan/wan_audio_runner.py                             +1    −1
lightx2v/models/runners/wan/wan_matrix_game2_runner.py                      +6    −5
lightx2v/models/runners/wan/wan_runner.py                                   +2    −2
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py                   +2    −0
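The diffs below all orbit one refactor: the single `offload_block_buffers` / `offload_phase_buffers` staging pool is split into `*_cuda_buffers` (device-side ping-pong slots) plus, when `lazy_load` streams weights from disk, matching `*_cpu_buffers` (host-side staging), and the old `is_offload_buffer` flag becomes the explicit pair `create_cuda_buffer` / `create_cpu_buffer`. A minimal sketch of that layout, with invented class and field names (the real classes are `WanTransformerWeights` and friends):

```python
# Illustrative-only sketch of the dual staging-buffer layout this commit
# introduces; names below are stand-ins, not the LightX2V classes.
class TransformerWeightsSketch:
    def __init__(self, config, lazy_load_path=None):
        self.lazy_load = config.get("lazy_load", False)
        # Device-side staging: two blocks so copy and compute can ping-pong.
        self.offload_block_cuda_buffers = [
            self._make_block(create_cuda_buffer=True) for _ in range(2)
        ]
        # Host-side staging exists only when weights stream in lazily from disk.
        self.offload_block_cpu_buffers = (
            [self._make_block(create_cpu_buffer=True) for _ in range(2)]
            if self.lazy_load
            else None
        )

    def _make_block(self, create_cuda_buffer=False, create_cpu_buffer=False):
        # Stand-in for WanTransformerAttentionBlock(..., create_cuda_buffer, create_cpu_buffer)
        return {"cuda": create_cuda_buffer, "cpu": create_cpu_buffer}
```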
lightx2v/models/networks/qwen_image/weights/post_weights.py

@@ -37,3 +37,13 @@ class QwenImagePostWeights(WeightModule):
                 self.lazy_load_file,
             ),
         )
+
+    def to_cpu(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cpu"):
+                module.to_cpu(non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cuda"):
+                module.to_cuda(non_blocking=non_blocking)
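Both Qwen-Image weight containers gain the same pair of hooks: forward `to_cpu` / `to_cuda` to every child module that implements them, so a whole pre/post weight group moves between host and device in one call. A standalone sketch of that dispatch pattern with toy classes (not the LightX2V ones):

```python
# Toy illustration of the to_cpu()/to_cuda() fan-out added above.
import torch

class LeafWeight:
    def __init__(self, tensor):
        self.tensor = tensor

    def to_cpu(self, non_blocking=True):
        self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        self.tensor = self.tensor.to("cuda", non_blocking=non_blocking)

class WeightGroup:
    def __init__(self, **children):
        self._modules = dict(children)

    def to_cpu(self, non_blocking=True):
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu"):
                module.to_cpu(non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda"):
                module.to_cuda(non_blocking=non_blocking)

group = WeightGroup(proj=LeafWeight(torch.randn(8, 8)), norm=None)  # None children are skipped
group.to_cpu()  # offloads every leaf in one call
```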
lightx2v/models/networks/qwen_image/weights/pre_weights.py  (mode 100644 → 100755)

@@ -28,3 +28,13 @@ class QwenImagePreWeights(WeightModule):
         self.add_module(
             "time_text_embed_timestep_embedder_linear_2",
             MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_2.weight", "time_text_embed.timestep_embedder.linear_2.bias"),
         )
+
+    def to_cpu(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cpu"):
+                module.to_cpu(non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cuda"):
+                module.to_cuda(non_blocking=non_blocking)
lightx2v/models/networks/qwen_image/weights/transformer_weights.py

@@ -15,7 +15,7 @@ class QwenImageTransformerWeights(WeightModule):
         self.mm_type = config.get("dit_quant_scheme", "Default")
         if self.mm_type != "Default":
             assert config.get("dit_quantized") is True
-        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "transformer_blocks") for i in range(self.blocks_num))
+        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, False, "transformer_blocks") for i in range(self.blocks_num))
         self.register_offload_buffers(config)
         self.add_module("blocks", blocks)
@@ -23,17 +23,17 @@ class QwenImageTransformerWeights(WeightModule):
         if config["cpu_offload"]:
             if config["offload_granularity"] == "block":
                 self.offload_blocks_num = 2
-                self.offload_block_buffers = WeightModuleList([QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, "transformer_blocks") for i in range(self.offload_blocks_num)])
-                self.add_module("offload_block_buffers", self.offload_block_buffers)
-                self.offload_phase_buffers = None
+                self.offload_block_cuda_buffers = WeightModuleList([QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, False, "transformer_blocks") for i in range(self.offload_blocks_num)])
+                self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
+                self.offload_phase_cuda_buffers = None
             else:
                 raise NotImplementedError


 class QwenImageTransformerAttentionBlock(WeightModule):
-    def __init__(self, block_index, task, mm_type, config, is_offload_buffer=False, block_prefix="transformer_blocks"):
+    def __init__(self, block_index, task, mm_type, config, create_cuda_buffer=False, create_cpu_buffer=False, block_prefix="transformer_blocks"):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -55,14 +55,15 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.img_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.img_mod.1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "img_norm1",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         self.attn = QwenImageCrossAttention(
             block_index=block_index,
@@ -70,7 +71,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -78,7 +80,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
         self.add_module(
             "img_norm2",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         img_mlp = QwenImageFFN(
             block_index=block_index,
@@ -87,7 +89,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -99,20 +102,21 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.txt_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.txt_mod.1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "txt_norm1",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         # Text doesn't need separate attention - it's handled by img_attn joint computation
         self.add_module(
             "txt_norm2",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         txt_mlp = QwenImageFFN(
             block_index=block_index,
@@ -121,7 +125,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -129,7 +134,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
 class QwenImageCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -146,12 +151,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_q
         self.add_module(
             "norm_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # norm_k
         self.add_module(
             "norm_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # to_q
         self.add_module(
@@ -159,7 +164,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_q.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -170,7 +176,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_k.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_k.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -181,7 +188,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_v.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_v.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -192,7 +200,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_q_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_q_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -203,7 +212,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_k_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_k_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -214,7 +224,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_v_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_v_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -225,7 +236,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_out.0.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_out.0.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -236,7 +248,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_add_out.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_add_out.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -244,12 +257,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_added_q
         self.add_module(
             "norm_added_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # norm_added_k
         self.add_module(
             "norm_added_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # attn
         self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]())
@@ -266,7 +279,7 @@ class QwenImageCrossAttention(WeightModule):
 class QwenImageFFN(WeightModule):
-    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -281,7 +294,8 @@ class QwenImageFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.weight",
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -291,7 +305,8 @@ class QwenImageFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.weight",
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
lightx2v/models/networks/wan/infer/feature_caching/transformer_infer.py

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
 from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE


 class WanTransformerInferCaching(WanOffloadTransformerInfer):
@@ -56,7 +57,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
             self.accumulated_rel_l1_distance_even = 0
         else:
             rescale_func = np.poly1d(self.coefficients)
-            self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even.cuda()).abs().mean() / self.previous_e0_even.cuda().abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_even += rescale_func(
+                ((modulated_inp - self.previous_e0_even.to(AI_DEVICE)).abs().mean() / self.previous_e0_even.to(AI_DEVICE).abs().mean()).cpu().item()
+            )
             if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
                 should_calc = False
             else:
@@ -72,7 +75,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
             self.accumulated_rel_l1_distance_odd = 0
         else:
             rescale_func = np.poly1d(self.coefficients)
-            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.cuda()).abs().mean() / self.previous_e0_odd.cuda().abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.to(AI_DEVICE)).abs().mean() / self.previous_e0_odd.to(AI_DEVICE).abs().mean()).cpu().item())
             if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
                 should_calc = False
             else:
@@ -149,9 +152,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def infer_using_cache(self, x):
         if self.scheduler.infer_condition:
-            x.add_(self.previous_residual_even.cuda())
+            x.add_(self.previous_residual_even.to(AI_DEVICE))
         else:
-            x.add_(self.previous_residual_odd.cuda())
+            x.add_(self.previous_residual_odd.to(AI_DEVICE))
         return x

     def clear(self):
@@ -1075,7 +1078,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
     def infer_using_cache(self, x):
         residual_x = self.residual_cache[self.scheduler.infer_condition]
-        x.add_(residual_x.cuda())
+        x.add_(residual_x.to(AI_DEVICE))
         return x

     def clear(self):
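These caching fixes replace hard-coded `.cuda()` calls with `.to(AI_DEVICE)`, where `AI_DEVICE` comes from `lightx2v_platform.base.global_var`. A small sketch of the pattern, with a stand-in resolution of `AI_DEVICE` since the platform module itself isn't shown in this commit:

```python
# Device-agnostic pattern behind the lightx2v_platform fixes: AI_DEVICE is a
# backend string ("cuda", "npu", ...) resolved once, and cached tensors move
# with .to(AI_DEVICE) instead of the CUDA-only .cuda(). Illustrative only.
import torch

AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in resolution

residual = torch.randn(4, 8)                 # cached on CPU between steps
x = torch.randn(4, 8, device=AI_DEVICE)
x.add_(residual.to(AI_DEVICE))               # works on any backend, unlike residual.cuda()
```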
lightx2v/models/networks/wan/infer/offload/transformer_infer.py

 import torch

-from lightx2v.common.offload.manager import (
-    LazyWeightAsyncStreamManager,
-    WeightAsyncStreamManager,
-)
+from lightx2v.common.offload.manager import WeightAsyncStreamManager
 from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


 class WanOffloadTransformerInfer(WanTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
         if self.config.get("cpu_offload", False):
             if "offload_ratio" in self.config:
                 self.offload_ratio = self.config["offload_ratio"]
             else:
                 self.offload_ratio = 1
             offload_granularity = self.config.get("offload_granularity", "block")
             if offload_granularity == "block":
-                if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_blocks_offload
-                else:
-                    self.infer_func = self.infer_with_blocks_lazy_offload
+                self.infer_func = self.infer_with_blocks_offload
             elif offload_granularity == "phase":
-                if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_phases_offload
-                else:
-                    self.infer_func = self.infer_with_phases_lazy_offload
+                self.infer_func = self.infer_with_phases_offload
                 self.phase_params = {
                     "shift_msa": None,
                     "scale_msa": None,
@@ -41,121 +31,54 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
             self.infer_func = self.infer_without_offload

         if offload_granularity != "model":
-            if not self.config.get("lazy_load", False):
-                self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
-            else:
-                self.offload_manager = LazyWeightAsyncStreamManager(
-                    blocks_num=self.blocks_num,
-                    offload_ratio=self.offload_ratio,
-                    phases_num=self.phases_num,
-                    num_disk_workers=self.config.get("num_disk_workers", 2),
-                    max_memory=self.config.get("max_memory", 2),
-                    offload_gra=offload_granularity,
-                )
+            self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)

     def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            if block_idx == 0:
+            if self.offload_manager.need_init_first_buffer:
                 self.offload_manager.init_first_buffer(blocks)
-            if block_idx < len(blocks) - 1:
-                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            self.offload_manager.prefetch_weights((block_idx + 1) % len(blocks), blocks)
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
             self.offload_manager.swap_blocks()
         return x

-    def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
-        self.offload_manager.prefetch_weights_from_disk(blocks)
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            if block_idx == 0:
-                block = self.offload_manager.pin_memory_buffer.get(block_idx)
-                block.to_cuda()
-                self.offload_manager.cuda_buffers[0] = (block_idx, block)
-            if block_idx < len(blocks) - 1:
-                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
-                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
-            self.offload_manager.swap_blocks()
-            if block_idx == len(blocks) - 1:
-                self.offload_manager.pin_memory_buffer.pop_front()
-            self.offload_manager._async_prefetch_block(blocks)
-        if self.clean_cuda_cache:
-            del (
-                pre_infer_out.embed0,
-                pre_infer_out.freqs,
-                pre_infer_out.context,
-            )
-            torch.cuda.empty_cache()
-        return x
-
-    def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
-        self.offload_manager.prefetch_weights_from_disk(blocks)
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
-            self.offload_manager._async_prefetch_block(blocks)
-        if self.clean_cuda_cache:
-            del (
-                self.phase_params["attn_out"],
-                self.phase_params["y_out"],
-                self.phase_params["y"],
-            )
-            torch.cuda.empty_cache()
-        if self.clean_cuda_cache:
-            self.clear_offload_params(pre_infer_out)
-        return x
-
     def infer_with_phases_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out)
         if self.clean_cuda_cache:
             del (
                 self.phase_params["attn_out"],
                 self.phase_params["y_out"],
                 self.phase_params["y"],
             )
-            torch.cuda.empty_cache()
+            torch_device_module.empty_cache()
         if self.clean_cuda_cache:
             self.clear_offload_params(pre_infer_out)
         return x

-    def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
+    def infer_phases(self, block_idx, blocks, x, pre_infer_out):
         for phase_idx in range(self.phases_num):
-            if block_idx == 0 and phase_idx == 0:
-                if lazy:
-                    obj_key = (block_idx, phase_idx)
-                    phase = self.offload_manager.pin_memory_buffer.get(obj_key)
-                    phase.to_cuda()
-                    self.offload_manager.cuda_buffers[0] = (obj_key, phase)
-                else:
-                    self.offload_manager.init_first_buffer(blocks)
-            is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
-            if not is_last_phase:
-                next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
-                next_phase_idx = (phase_idx + 1) % self.phases_num
-                self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            if self.offload_manager.need_init_first_buffer:
+                self.offload_manager.init_first_buffer(blocks)
+            next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
+            next_phase_idx = (phase_idx + 1) % self.phases_num
+            self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
             self.offload_manager.swap_phases()
@@ -176,10 +99,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
         self.phase_params["y_out"] = self.infer_self_attn(
             cur_phase,
             pre_infer_out.grid_sizes.tuple,
             x,
             pre_infer_out.seq_lens,
             pre_infer_out.freqs,
             self.phase_params["shift_msa"],
             self.phase_params["scale_msa"],
         )
@@ -219,7 +139,6 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         )
         del (
             pre_infer_out.embed0,
             pre_infer_out.freqs,
             pre_infer_out.context,
         )
-        torch.cuda.empty_cache()
+        torch_device_module.empty_cache()
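The reworked `infer_with_blocks_offload` drops the lazy-load variants and always prefetches block `(block_idx + 1) % len(blocks)` while the current block runs on the manager's `compute_stream` (streams now come from `getattr(torch, AI_DEVICE)` rather than `torch.cuda`). A minimal double-buffering sketch of that copy/compute overlap, with plain tensors standing in for weight blocks and hypothetical buffer and event names:

```python
# Sketch of prefetch/compute overlap, assuming a CUDA device. The real
# WeightAsyncStreamManager also swaps ping-pong buffers via swap_blocks();
# this only shows the stream and event choreography.
import torch

device_module = getattr(torch, "cuda")  # the commit generalizes this via AI_DEVICE
copy_stream = device_module.Stream()
compute_stream = device_module.Stream()

cpu_blocks = [torch.randn(1024, 1024).pin_memory() for _ in range(4)]  # stand-in weights
gpu_bufs = [torch.empty(1024, 1024, device="cuda") for _ in range(2)]  # ping-pong slots
ready = [device_module.Event(), device_module.Event()]

def prefetch(i):
    # Don't overwrite a slot until the compute that last read it has drained.
    copy_stream.wait_stream(compute_stream)
    with device_module.stream(copy_stream):
        gpu_bufs[i % 2].copy_(cpu_blocks[i], non_blocking=True)
        ready[i % 2].record(copy_stream)

x = torch.randn(1024, 1024, device="cuda")
prefetch(0)                                      # analogue of init_first_buffer
for i in range(len(cpu_blocks)):
    prefetch((i + 1) % len(cpu_blocks))          # wrap-around prefetch, as in the diff
    with device_module.stream(compute_stream):
        compute_stream.wait_event(ready[i % 2])  # wait only for this block's H2D copy
        x = x @ gpu_bufs[i % 2]
device_module.synchronize()
```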
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py  (mode 100644 → 100755)

@@ -6,6 +6,7 @@ import torch
 from lightx2v.models.networks.wan.infer.module_io import GridOutput
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE


 def sinusoidal_embedding_1d(dim, position):
@@ -50,7 +51,7 @@ class WanSFPreInfer(WanPreInfer):
                 rope_params(1024, 2 * (d // 6)),
             ],
             dim=1,
-        ).cuda()
+        ).to(AI_DEVICE)

     def time_embedding(self, weights, embed):
         embed = weights.time_embedding_0.apply(embed)
lightx2v/models/networks/wan/infer/vace/transformer_infer.py

@@ -9,6 +9,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}

     def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
         pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
         self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)
@@ -23,11 +24,11 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         pre_infer_out.adapter_args["hints"] = []
         self.infer_state = "vace"
         if hasattr(self, "offload_manager"):
-            self.offload_manager.init_cuda_buffer(self.vace_offload_block_buffers, self.vace_offload_phase_buffers)
+            self.offload_manager.init_cuda_buffer(self.vace_offload_block_cuda_buffers, self.vace_offload_phase_cuda_buffers)
         self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
         self.infer_state = "base"
         if hasattr(self, "offload_manager"):
-            self.offload_manager.init_cuda_buffer(self.offload_block_buffers, self.offload_phase_buffers)
+            self.offload_manager.init_cuda_buffer(self.offload_block_cuda_buffers, self.offload_phase_cuda_buffers)

     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
lightx2v/models/networks/wan/model.py

@@ -47,7 +47,10 @@ class WanModel(CompiledMethodsMixin):
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
         self.model_type = model_type
         self.remove_keys = []
+        self.lazy_load = self.config.get("lazy_load", False)
+        if self.lazy_load:
+            self.remove_keys.extend(["blocks."])
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
         else:
@@ -146,7 +149,7 @@ class WanModel(CompiledMethodsMixin):
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.config["parallel"]:
+        if self.device.type != "cpu" and dist.is_initialized():
             device = dist.get_rank()
         else:
             device = str(self.device)
@@ -169,6 +172,10 @@ class WanModel(CompiledMethodsMixin):
         else:
             safetensors_files = [safetensors_path]

+        if self.lazy_load:
+            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
+            self.lazy_load_path = safetensors_files[0]
+
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
@@ -205,6 +212,10 @@ class WanModel(CompiledMethodsMixin):
             safetensors_files = [safetensors_path]
             safetensors_path = os.path.dirname(safetensors_path)

+        if self.lazy_load:
+            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
+            self.lazy_load_path = safetensors_files[0]
+
         weight_dict = {}
         for safetensor_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
@@ -237,28 +248,6 @@ class WanModel(CompiledMethodsMixin):
         return weight_dict

-    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):
-        # Need rewrite
-        lazy_load_model_path = self.dit_quantized_ckpt
-        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
-        pre_post_weight_dict = {}
-        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
-        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
-            for k in f.keys():
-                if f.get_tensor(k).dtype in [
-                    torch.float16,
-                    torch.bfloat16,
-                    torch.float,
-                ]:
-                    if unified_dtype or all(s not in k for s in sensitive_layer):
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
-                    else:
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
-                else:
-                    pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
-        return pre_post_weight_dict
-
     def _load_gguf_ckpt(self, gguf_path):
         state_dict = load_gguf_sd_ckpt(gguf_path, to_device=self.device)
         return state_dict
@@ -285,10 +274,7 @@ class WanModel(CompiledMethodsMixin):
             weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
         else:
             # Load quantized weights
-            if not self.config.get("lazy_load", False):
-                weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
-            else:
-                weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+            weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

         if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
             weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
@@ -302,7 +288,10 @@ class WanModel(CompiledMethodsMixin):
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
-        self.transformer_weights = self.transformer_weight_class(self.config)
+        if self.lazy_load:
+            self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path)
+        else:
+            self.transformer_weights = self.transformer_weight_class(self.config)

         if not self._should_init_empty_model():
             self._apply_weights()
@@ -383,7 +372,9 @@ class WanModel(CompiledMethodsMixin):
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
+            if self.lazy_load:
+                self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
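Under the reconstructed disk offload, `lazy_load` requires a single safetensors file: the model records `lazy_load_path` and the weight classes keep a `safe_open` handle, pulling tensors on demand instead of materializing the whole checkpoint. A sketch of that on-demand read using the real `safetensors.safe_open` API, with a placeholder filename and weight key:

```python
# On-demand tensor access behind lazy_load: hold one handle to the single
# checkpoint file and fetch a block's tensors only when that block is staged.
from safetensors import safe_open

lazy_load_file = safe_open("model.safetensors", framework="pt", device="cpu")
key = "blocks.0.self_attn.q.weight"      # hypothetical weight name
if key in lazy_load_file.keys():
    w = lazy_load_file.get_tensor(key)   # reads just this tensor from disk
```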
lightx2v/models/networks/wan/vace_model.py

@@ -22,10 +22,15 @@ class WanVaceModel(WanModel):
     def _init_infer(self):
         super()._init_infer()
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_block_buffers = self.transformer_weights.offload_block_buffers
-            self.transformer_infer.offload_phase_buffers = self.transformer_weights.offload_phase_buffers
-            self.transformer_infer.vace_offload_block_buffers = self.transformer_weights.vace_offload_block_buffers
-            self.transformer_infer.vace_offload_phase_buffers = self.transformer_weights.vace_offload_phase_buffers
+            self.transformer_infer.offload_block_cuda_buffers = self.transformer_weights.offload_block_cuda_buffers
+            self.transformer_infer.offload_phase_cuda_buffers = self.transformer_weights.offload_phase_cuda_buffers
+            self.transformer_infer.vace_offload_block_cuda_buffers = self.transformer_weights.vace_offload_block_cuda_buffers
+            self.transformer_infer.vace_offload_phase_cuda_buffers = self.transformer_weights.vace_offload_phase_cuda_buffers
+            if self.lazy_load:
+                self.transformer_infer.offload_block_cpu_buffers = self.transformer_weights.offload_block_cpu_buffers
+                self.transformer_infer.offload_phase_cpu_buffers = self.transformer_weights.offload_phase_cpu_buffers
+                self.transformer_infer.vace_offload_block_cpu_buffers = self.transformer_weights.vace_offload_block_cpu_buffers
+                self.transformer_infer.vace_offload_phase_cpu_buffers = self.transformer_weights.vace_offload_phase_cpu_buffers

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
lightx2v/models/networks/wan/weights/animate/transformer_weights.py

@@ -26,15 +26,19 @@ class WanAnimateTransformerWeights(WanTransformerWeights):
             self._add_animate_fuserblock_to_offload_buffers()

     def _add_animate_fuserblock_to_offload_buffers(self):
-        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+        if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
             for i in range(self.offload_blocks_num):
-                self.offload_block_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
-        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
-            self.offload_phase_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
+                self.offload_block_cuda_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
+                if self.lazy_load:
+                    self.offload_block_cpu_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))
+        elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
+            self.offload_phase_cuda_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
+            if self.lazy_load:
+                self.offload_phase_cpu_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))


 class WanAnimateFuserBlock(WeightModule):
-    def __init__(self, config, block_index, block_prefix, mm_type, is_offload_buffer=False):
+    def __init__(self, config, block_index, block_prefix, mm_type, create_cuda_buffer=False, create_cpu_buffer=False):
         super().__init__()
         self.config = config
         self.is_post_adapter = True
@@ -53,7 +57,8 @@ class WanAnimateFuserBlock(WeightModule):
             MM_WEIGHT_REGISTER[mm_type](
                 f"{block_prefix}.{block_index}.linear1_kv.weight",
                 f"{block_prefix}.{block_index}.linear1_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -65,7 +70,8 @@ class WanAnimateFuserBlock(WeightModule):
             MM_WEIGHT_REGISTER[mm_type](
                 f"{block_prefix}.{block_index}.linear1_q.weight",
                 f"{block_prefix}.{block_index}.linear1_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -76,7 +82,8 @@ class WanAnimateFuserBlock(WeightModule):
             MM_WEIGHT_REGISTER[mm_type](
                 f"{block_prefix}.{block_index}.linear2.weight",
                 f"{block_prefix}.{block_index}.linear2.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -87,7 +94,8 @@ class WanAnimateFuserBlock(WeightModule):
             "q_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.q_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -98,7 +106,8 @@ class WanAnimateFuserBlock(WeightModule):
             "k_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.k_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
lightx2v/models/networks/wan/weights/audio/transformer_weights.py

@@ -19,6 +19,7 @@ class WanAudioTransformerWeights(WanTransformerWeights):
                 self.mm_type,
                 self.config,
                 False,
+                False,
                 self.blocks[i].lazy_load,
                 self.blocks[i].lazy_load_file,
             )
@@ -27,37 +28,66 @@ class WanAudioTransformerWeights(WanTransformerWeights):
             self._add_audio_adapter_ca_to_offload_buffers()

     def _add_audio_adapter_ca_to_offload_buffers(self):
-        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+        if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
             for i in range(self.offload_blocks_num):
-                offload_buffer = self.offload_block_buffers[i]
+                offload_buffer = self.offload_block_cuda_buffers[i]
                 adapter_ca = WanAudioAdapterCA(
                     block_index=i,
                     block_prefix=f"ca",
                     task=self.task,
                     mm_type=self.mm_type,
                     config=self.config,
-                    is_offload_buffer=True,
+                    create_cuda_buffer=True,
+                    create_cpu_buffer=False,
                     lazy_load=offload_buffer.lazy_load,
                     lazy_load_file=offload_buffer.lazy_load_file,
                 )
                 offload_buffer.compute_phases.append(adapter_ca)
+                if self.lazy_load:
+                    offload_buffer = self.offload_block_cpu_buffers[i]
+                    adapter_ca = WanAudioAdapterCA(
+                        block_index=i,
+                        block_prefix=f"ca",
+                        task=self.task,
+                        mm_type=self.mm_type,
+                        config=self.config,
+                        create_cuda_buffer=False,
+                        create_cpu_buffer=True,
+                        lazy_load=offload_buffer.lazy_load,
+                        lazy_load_file=offload_buffer.lazy_load_file,
+                    )
+                    offload_buffer.compute_phases.append(adapter_ca)
-        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
+        elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
             adapter_ca = WanAudioAdapterCA(
                 block_index=0,
                 block_prefix=f"ca",
                 task=self.task,
                 mm_type=self.mm_type,
                 config=self.config,
-                is_offload_buffer=True,
+                create_cuda_buffer=True,
+                create_cpu_buffer=False,
                 lazy_load=self.blocks[0].lazy_load,
                 lazy_load_file=self.blocks[0].lazy_load_file,
             )
-            self.offload_phase_buffers.append(adapter_ca)
+            self.offload_phase_cuda_buffers.append(adapter_ca)
+            if self.lazy_load:
+                adapter_ca = WanAudioAdapterCA(
+                    block_index=0,
+                    block_prefix=f"ca",
+                    task=self.task,
+                    mm_type=self.mm_type,
+                    config=self.config,
+                    create_cuda_buffer=False,
+                    create_cpu_buffer=True,
+                    lazy_load=self.blocks[0].lazy_load,
+                    lazy_load_file=self.blocks[0].lazy_load_file,
+                )
+                self.offload_phase_cpu_buffers.append(adapter_ca)


 class WanAudioAdapterCA(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -71,7 +101,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_q.weight",
                 f"{block_prefix}.{block_index}.to_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -82,7 +113,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_kv.weight",
                 f"{block_prefix}.{block_index}.to_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -93,7 +125,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_out.weight",
                 f"{block_prefix}.{block_index}.to_out.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -104,7 +137,8 @@ class WanAudioAdapterCA(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.norm_kv.weight",
                 f"{block_prefix}.{block_index}.norm_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -119,7 +153,8 @@ class WanAudioAdapterCA(WeightModule):
             "shift_scale_gate",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.shift_scale_gate",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
lightx2v/models/networks/wan/weights/matrix_game2/transformer_weights.py  (mode 100644 → 100755)

-import os
-
-from safetensors import safe_open
-
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanFFN,
@@ -31,9 +27,9 @@ class WanActionTransformerWeights(WeightModule):
         block_list = []
         for i in range(self.blocks_num):
             if i in action_blocks:
-                block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config, "blocks"))
+                block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config))
             else:
-                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "blocks"))
+                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config))
         self.blocks = WeightModuleList(block_list)
         self.add_module("blocks", self.blocks)
@@ -42,11 +38,6 @@ class WanActionTransformerWeights(WeightModule):
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

-    def clear(self):
-        for block in self.blocks:
-            for phase in block.compute_phases:
-                phase.clear()
-
     def non_block_weights_to_cuda(self):
         self.norm.to_cuda()
         self.head.to_cuda()
@@ -66,34 +57,16 @@ class WanTransformerActionBlock(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.lazy_load = self.config.get("lazy_load", False)
-        if self.lazy_load:
-            lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
-            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
-        else:
-            self.lazy_load_file = None
+        assert not self.config.get("lazy_load", False)

         self.compute_phases = WeightModuleList(
             [
-                WanSelfAttention(
-                    block_index,
-                    block_prefix,
-                    task,
-                    mm_type,
-                    config,
-                    False,
-                    self.lazy_load,
-                    self.lazy_load_file,
-                ),
+                WanSelfAttention(block_index, block_prefix, task, mm_type, config),
                 WanActionCrossAttention(
                     block_index,
                     block_prefix,
                     task,
                     mm_type,
                     config,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
                 WanActionModule(
                     block_index,
@@ -101,8 +74,6 @@ class WanTransformerActionBlock(WeightModule):
                     task,
                     mm_type,
                     config,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
                 WanFFN(
                     block_index,
@@ -110,9 +81,6 @@ class WanTransformerActionBlock(WeightModule):
                     task,
                     mm_type,
                     config,
                     False,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
             ]
         )
@@ -121,7 +89,7 @@ class WanTransformerActionBlock(WeightModule):
 class WanActionModule(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -129,9 +97,6 @@ class WanActionModule(WeightModule):
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
         self.attn_rms_type = "self_forcing"

         self.add_module(
@@ -139,8 +104,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.weight",
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -148,8 +111,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.weight",
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
@@ -158,8 +119,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.proj_keyboard.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
@@ -168,8 +127,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_attn_kv.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
@@ -180,8 +137,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_attn_q.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
@@ -191,8 +146,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.t_qkv.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
@@ -201,8 +154,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.proj_mouse.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
@@ -211,8 +162,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.weight",
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -220,8 +169,6 @@ class WanActionModule(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.weight",
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -229,22 +176,18 @@ class WanActionModule(WeightModule):
             LN_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.weight",
                 f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.bias",
-                self.lazy_load,
-                self.lazy_load_file,
                 eps=1e-6,
             ),
         )


 class WanActionCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
-        self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
         if self.config.get("sf_config", False):
             self.attn_rms_type = "self_forcing"
@@ -256,8 +199,6 @@ class WanActionCrossAttention(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.norm3.weight",
                 f"{block_prefix}.{self.block_index}.norm3.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -265,8 +206,6 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -274,8 +213,6 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -283,8 +220,6 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -292,24 +227,18 @@ class WanActionCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
             "cross_attn_norm_q",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
             "cross_attn_norm_k",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
lightx2v/models/networks/wan/weights/transformer_weights.py
View file @
74eeb429
import
os
from
safetensors
import
safe_open
from
lightx2v.common.modules.weight_module
import
WeightModule
,
WeightModuleList
...
...
@@ -13,7 +11,7 @@ from lightx2v.utils.registry_factory import (
class
WanTransformerWeights
(
WeightModule
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
,
lazy_load_path
=
None
):
super
().
__init__
()
self
.
blocks_num
=
config
[
"num_layers"
]
self
.
task
=
config
[
"task"
]
...
...
@@ -23,7 +21,27 @@ class WanTransformerWeights(WeightModule):
assert
config
.
get
(
"dit_quantized"
)
is
True
if
config
.
get
(
"do_mm_calib"
,
False
):
self
.
mm_type
=
"Calib"
self
.
blocks
=
WeightModuleList
([
WanTransformerAttentionBlock
(
i
,
self
.
task
,
self
.
mm_type
,
self
.
config
)
for
i
in
range
(
self
.
blocks_num
)])
self
.
lazy_load
=
self
.
config
.
get
(
"lazy_load"
,
False
)
if
not
self
.
lazy_load
:
self
.
lazy_load_file
=
None
else
:
self
.
lazy_load_file
=
safe_open
(
lazy_load_path
,
framework
=
"pt"
,
device
=
"cpu"
)
self
.
blocks
=
WeightModuleList
(
[
WanTransformerAttentionBlock
(
block_index
=
i
,
task
=
self
.
task
,
mm_type
=
self
.
mm_type
,
config
=
self
.
config
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load_file
=
self
.
lazy_load_file
,
)
for
i
in
range
(
self
.
blocks_num
)
]
)
self
.
register_offload_buffers
(
config
)
self
.
add_module
(
"blocks"
,
self
.
blocks
)
...
...
@@ -36,35 +54,74 @@ class WanTransformerWeights(WeightModule):
if
config
[
"cpu_offload"
]:
if
config
[
"offload_granularity"
]
==
"block"
:
self
.
offload_blocks_num
=
2
self
.
offload_block_buffers
=
WeightModuleList
(
self
.
offload_block_
cuda_
buffers
=
WeightModuleList
(
[
WanTransformerAttentionBlock
(
i
,
self
.
task
,
self
.
mm_type
,
self
.
config
,
True
,
block_index
=
i
,
task
=
self
.
task
,
mm_type
=
self
.
mm_type
,
config
=
self
.
config
,
create_cuda_buffer
=
True
,
create_cpu_buffer
=
False
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load_file
=
self
.
lazy_load_file
,
)
for
i
in
range
(
self
.
offload_blocks_num
)
]
)
self
.
add_module
(
"offload_block_buffers"
,
self
.
offload_block_buffers
)
self
.
offload_phase_buffers
=
None
self
.
add_module
(
"offload_block_cuda_buffers"
,
self
.
offload_block_cuda_buffers
)
self
.
offload_phase_cuda_buffers
=
None
if
self
.
lazy_load
:
self
.
offload_blocks_num
=
2
self
.
offload_block_cpu_buffers
=
WeightModuleList
(
[
WanTransformerAttentionBlock
(
block_index
=
i
,
task
=
self
.
task
,
mm_type
=
self
.
mm_type
,
config
=
self
.
config
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
True
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load_file
=
self
.
lazy_load_file
,
)
for
i
in
range
(
self
.
offload_blocks_num
)
]
)
self
.
add_module
(
"offload_block_cpu_buffers"
,
self
.
offload_block_cpu_buffers
)
self
.
offload_phase_cpu_buffers
=
None
elif
config
[
"offload_granularity"
]
==
"phase"
:
self
.
offload_phase_buffers
=
WanTransformerAttentionBlock
(
0
,
self
.
task
,
self
.
mm_type
,
self
.
config
,
True
,
self
.
offload_phase_cuda_buffers
=
WanTransformerAttentionBlock
(
block_index
=
0
,
task
=
self
.
task
,
mm_type
=
self
.
mm_type
,
config
=
self
.
config
,
create_cuda_buffer
=
True
,
create_cpu_buffer
=
False
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load_file
=
self
.
lazy_load_file
,
).
compute_phases
self
.
add_module
(
"offload_phase_buffers"
,
self
.
offload_phase_buffers
)
self
.
offload_block_buffers
=
None
def
clear
(
self
):
for
block
in
self
.
blocks
:
for
phase
in
block
.
compute_phases
:
phase
.
clear
()
self
.
add_module
(
"offload_phase_cuda_buffers"
,
self
.
offload_phase_cuda_buffers
)
self
.
offload_block_cuda_buffers
=
None
if
self
.
lazy_load
:
self
.
offload_phase_cpu_buffers
=
WanTransformerAttentionBlock
(
block_index
=
0
,
task
=
self
.
task
,
mm_type
=
self
.
mm_type
,
config
=
self
.
config
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
True
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load_file
=
self
.
lazy_load_file
,
).
compute_phases
self
.
add_module
(
"offload_phase_cpu_buffers"
,
self
.
offload_phase_cpu_buffers
)
self
.
offload_block_cpu_buffers
=
None
def
non_block_weights_to_cuda
(
self
):
self
.
norm
.
to_cuda
()
...
...
@@ -84,23 +141,23 @@ class WanTransformerAttentionBlock(WeightModule):
        task,
        mm_type,
        config,
-       is_offload_buffer=False,
+       create_cuda_buffer=False,
+       create_cpu_buffer=False,
        block_prefix="blocks",
+       lazy_load=False,
+       lazy_load_file=None,
    ):
        super().__init__()
        self.block_index = block_index
        self.mm_type = mm_type
        self.task = task
        self.config = config
-       self.is_offload_buffer = is_offload_buffer
+       self.create_cuda_buffer = create_cuda_buffer
+       self.create_cpu_buffer = create_cpu_buffer
        self.quant_method = config.get("quant_method", None)
-       self.lazy_load = self.config.get("lazy_load", False)
-       if self.lazy_load:
-           lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
-           self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
-       else:
-           self.lazy_load_file = None
+       self.lazy_load = lazy_load
+       self.lazy_load_file = lazy_load_file
        self.compute_phases = WeightModuleList(
            [
...
...
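The lazy_load rework in this hunk is also behavioral: previously every block called safe_open on its own block_{i}.safetensors in __init__, so each block held a file handle; now the owner opens the file once and threads the handle down via the new lazy_load_file parameter, and children materialize tensors only on demand. A sketch of the handle-sharing idea (hypothetical helper, not the repo's API):

from safetensors import safe_open

# opened once by the parent module ...
handle = safe_open("block_0.safetensors", framework="pt", device="cpu")

# ... and passed to every child, which pulls a tensor only when first needed
def lazy_get(handle, key):
    return handle.get_tensor(key)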
@@ -110,7 +167,8 @@ class WanTransformerAttentionBlock(WeightModule):
                    task,
                    mm_type,
                    config,
-                   is_offload_buffer,
+                   create_cuda_buffer,
+                   create_cpu_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
...
...
@@ -120,7 +178,8 @@ class WanTransformerAttentionBlock(WeightModule):
                    task,
                    mm_type,
                    config,
-                   is_offload_buffer,
+                   create_cuda_buffer,
+                   create_cpu_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
...
...
@@ -130,7 +189,8 @@ class WanTransformerAttentionBlock(WeightModule):
                    task,
                    mm_type,
                    config,
-                   is_offload_buffer,
+                   create_cuda_buffer,
+                   create_cpu_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
...
...
@@ -148,9 +208,10 @@ class WanSelfAttention(WeightModule):
        task,
        mm_type,
        config,
-       is_offload_buffer,
-       lazy_load,
-       lazy_load_file,
+       create_cuda_buffer=False,
+       create_cpu_buffer=False,
+       lazy_load=False,
+       lazy_load_file=None,
    ):
        super().__init__()
        self.block_index = block_index
...
...
@@ -171,7 +232,8 @@ class WanSelfAttention(WeightModule):
"modulation"
,
TENSOR_REGISTER
[
"Default"
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.modulation"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -187,7 +249,8 @@ class WanSelfAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.self_attn.q.weight",
                f"{block_prefix}.{self.block_index}.self_attn.q.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -198,7 +261,8 @@ class WanSelfAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.self_attn.k.weight",
                f"{block_prefix}.{self.block_index}.self_attn.k.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -208,7 +272,8 @@ class WanSelfAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.self_attn.v.weight",
                f"{block_prefix}.{self.block_index}.self_attn.v.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -218,7 +283,8 @@ class WanSelfAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.self_attn.o.weight",
                f"{block_prefix}.{self.block_index}.self_attn.o.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -227,7 +293,8 @@ class WanSelfAttention(WeightModule):
"self_attn_norm_q"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.self_attn.norm_q.weight"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -236,7 +303,8 @@ class WanSelfAttention(WeightModule):
"self_attn_norm_k"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.self_attn.norm_k.weight"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -278,7 +346,8 @@ class WanSelfAttention(WeightModule):
"smooth_norm1_weight"
,
TENSOR_REGISTER
[
"Default"
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm1.weight"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -287,7 +356,8 @@ class WanSelfAttention(WeightModule):
"smooth_norm1_bias"
,
TENSOR_REGISTER
[
"Default"
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm1.bias"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -302,9 +372,10 @@ class WanCrossAttention(WeightModule):
        task,
        mm_type,
        config,
-       is_offload_buffer,
-       lazy_load,
-       lazy_load_file,
+       create_cuda_buffer=False,
+       create_cpu_buffer=False,
+       lazy_load=False,
+       lazy_load_file=None,
    ):
        super().__init__()
        self.block_index = block_index
...
...
@@ -324,7 +395,8 @@ class WanCrossAttention(WeightModule):
            LN_WEIGHT_REGISTER["Default"](
                f"{block_prefix}.{self.block_index}.norm3.weight",
                f"{block_prefix}.{self.block_index}.norm3.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -334,7 +406,8 @@ class WanCrossAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
                f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -344,7 +417,8 @@ class WanCrossAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
                f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -354,7 +428,8 @@ class WanCrossAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
                f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -364,7 +439,8 @@ class WanCrossAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
                f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -373,7 +449,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_q"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_q.weight"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -382,7 +459,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_k.weight"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -395,7 +473,8 @@ class WanCrossAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight",
                f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -405,7 +484,8 @@ class WanCrossAttention(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight",
                f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -414,7 +494,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k_img"
,
RMS_WEIGHT_REGISTER
[
self
.
attn_rms_type
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.cross_attn.norm_k_img.weight"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -430,9 +511,10 @@ class WanFFN(WeightModule):
        task,
        mm_type,
        config,
-       is_offload_buffer,
-       lazy_load,
-       lazy_load_file,
+       create_cuda_buffer=False,
+       create_cpu_buffer=False,
+       lazy_load=False,
+       lazy_load_file=None,
    ):
        super().__init__()
        self.block_index = block_index
...
...
@@ -453,7 +535,8 @@ class WanFFN(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.ffn.0.weight",
                f"{block_prefix}.{self.block_index}.ffn.0.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -463,7 +546,8 @@ class WanFFN(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.ffn.2.weight",
                f"{block_prefix}.{self.block_index}.ffn.2.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
...
@@ -474,7 +558,8 @@ class WanFFN(WeightModule):
"smooth_norm2_weight"
,
TENSOR_REGISTER
[
"Default"
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm3.weight"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
@@ -483,7 +568,8 @@ class WanFFN(WeightModule):
"smooth_norm2_bias"
,
TENSOR_REGISTER
[
"Default"
](
f
"
{
block_prefix
}
.
{
self
.
block_index
}
.affine_norm3.bias"
,
is_offload_buffer
,
create_cuda_buffer
,
create_cpu_buffer
,
self
.
lazy_load
,
self
.
lazy_load_file
,
),
...
...
lightx2v/models/networks/wan/weights/vace/transformer_weights.py
...
...
@@ -15,7 +15,7 @@ class WanVaceTransformerWeights(WanTransformerWeights):
        self.patch_size = (1, 2, 2)
        self.register_offload_buffers(config)
        self.vace_blocks = WeightModuleList(
-           [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
+           [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
        )
        self.add_module("vace_blocks", self.vace_blocks)
        self.add_module(
...
...
@@ -27,23 +27,17 @@ class WanVaceTransformerWeights(WanTransformerWeights):
        super().register_offload_buffers(config)
        if config["cpu_offload"]:
            if config["offload_granularity"] == "block":
-               self.vace_offload_block_buffers = WeightModuleList(
-                   [
-                       WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
-                       WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
-                   ]
-               )
-               self.add_module("vace_offload_block_buffers", self.vace_offload_block_buffers)
-               self.vace_offload_phase_buffers = None
+               self.vace_offload_block_cuda_buffers = WeightModuleList(
+                   [
+                       WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
+                       WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
+                   ]
+               )
+               self.add_module("vace_offload_block_cuda_buffers", self.vace_offload_block_cuda_buffers)
+               self.vace_offload_phase_cuda_buffers = None
            elif config["offload_granularity"] == "phase":
                raise NotImplementedError

-   def clear(self):
-       super().clear()
-       for vace_block in self.vace_blocks:
-           for vace_phase in vace_block.compute_phases:
-               vace_phase.clear()
-
    def non_block_weights_to_cuda(self):
        super().non_block_weights_to_cuda()
        self.vace_patch_embedding.to_cuda()
...
...
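As with the base class, the VACE variant keeps exactly two GPU-resident block buffers (the two-element list above), while phase-granularity offload for VACE blocks remains unimplemented and now fails fast with NotImplementedError; the per-class clear() override is removed along with the base class's clear().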
@@ -54,15 +48,16 @@ class WanVaceTransformerWeights(WanTransformerWeights):
class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
-   def __init__(self, base_block_idx, block_index, task, mm_type, config, is_offload_buffer, block_prefix):
-       super().__init__(block_index, task, mm_type, config, is_offload_buffer, block_prefix)
+   def __init__(self, base_block_idx, block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix):
+       super().__init__(block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix)
        if base_block_idx == 0:
            self.compute_phases[0].add_module(
                "before_proj",
                MM_WEIGHT_REGISTER[self.mm_type](
                    f"{block_prefix}.{self.block_index}.before_proj.weight",
                    f"{block_prefix}.{self.block_index}.before_proj.bias",
-                   is_offload_buffer,
+                   create_cuda_buffer,
+                   create_cpu_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
...
...
@@ -73,7 +68,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
                MM_WEIGHT_REGISTER[self.mm_type](
                    f"{block_prefix}.{self.block_index}.after_proj.weight",
                    f"{block_prefix}.{self.block_index}.after_proj.bias",
-                   is_offload_buffer,
+                   create_cuda_buffer,
+                   create_cpu_buffer,
                    self.lazy_load,
                    self.lazy_load_file,
                ),
...
...
lightx2v/models/runners/default_runner.py
...
...
@@ -41,7 +41,8 @@ class DefaultRunner(BaseRunner):
            self.load_model()
        elif self.config.get("lazy_load", False):
            assert self.config.get("cpu_offload", False)
-       self.model.set_scheduler(self.scheduler)  # set scheduler to model
+       if hasattr(self, "model"):
+           self.model.set_scheduler(self.scheduler)  # set scheduler to model
        if self.config["task"] == "i2v":
            self.run_input_encoder = self._run_input_encoder_local_i2v
        elif self.config["task"] == "flf2v":
...
...
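The hasattr guard matters because, with lazy_load enabled, the transformer is not constructed during runner initialization, so there is no model to attach the scheduler to yet; the @@ -279 hunk further down re-attaches the scheduler immediately after load_transformer() builds the model on first use.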
@@ -184,11 +185,6 @@ class DefaultRunner(BaseRunner):
        del self.inputs
        self.input_info = None
-       if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
-           if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
-               self.model.transformer_infer.weights_stream_mgr.clear()
-           if hasattr(self.model.transformer_weights, "clear"):
-               self.model.transformer_weights.clear()
-           self.model.pre_weight.clear()
        del self.model
        if self.config.get("do_mm_calib", False):
            calib_path = os.path.join(os.getcwd(), "calib.pt")
...
...
@@ -279,6 +275,7 @@ class DefaultRunner(BaseRunner):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()
+           self.model.set_scheduler(self.scheduler)
        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
        if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
...
...
lightx2v/models/runners/wan/wan_animate_runner.py
...
...
@@ -24,6 +24,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
+from lightx2v_platform.base.global_var import AI_DEVICE


@RUNNER_REGISTER("wan2.2_animate")
...
...
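The recurring edit across the runner files below is mechanical: hard-coded .cuda() calls become .to(AI_DEVICE), where AI_DEVICE is the device object exported by lightx2v_platform.base.global_var. Its definition is not part of this diff; conceptually it is a process-wide torch.device chosen per platform, along the lines of this sketch (illustrative only; the real platform layer may probe other backends, such as NPUs, first):

import torch

def _detect_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

AI_DEVICE = _detect_device()

# call sites: tensor.to(AI_DEVICE) works on any backend,
# whereas tensor.cuda() assumes an NVIDIA GPU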
@@ -182,7 +183,7 @@ class WanAnimateRunner(WanRunner):
                ],
                dim=1,
            )
-           .cuda()
+           .to(AI_DEVICE)
            .unsqueeze(0)
        )
        mask_pixel_values = 1 - mask_pixel_values
...
...
@@ -210,7 +211,7 @@ class WanAnimateRunner(WanRunner):
                ],
                dim=1,
            )
-           .cuda()
+           .to(AI_DEVICE)
            .unsqueeze(0)
        )
        msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
...
...
@@ -330,7 +331,7 @@ class WanAnimateRunner(WanRunner):
                dtype=GET_DTYPE(),
            )  # c t h w
        else:
-           refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().cuda()  # c t h w
+           refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().to(AI_DEVICE)  # c t h w
        bg_pixel_values, mask_pixel_values = None, None
        if self.config["replace_flag"] if "replace_flag" in self.config else False:
...
...
@@ -408,8 +409,8 @@ class WanAnimateRunner(WanRunner):
        return model

    def load_encoders(self):
-       motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
-       face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
+       motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
+       face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
        motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
        face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")
        motion_encoder.load_state_dict(motion_weight_dict)
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
...
...
@@ -435,7 +435,7 @@ class WanAudioRunner(WanRunner): # type:ignore
    def process_single_mask(self, mask_file):
        mask_img = load_image(mask_file)
-       mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
+       mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
        if mask_img.shape[1] == 3:  # If it is an RGB three-channel image
            mask_img = mask_img[:, :1]  # Only take the first channel
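For context on the unchanged arithmetic in this hunk: TF.to_tensor yields values in [0, 1], and the in-place sub_(0.5).div_(0.5) chain rescales them to the symmetric [-1, 1] range, e.g.:

import torch

x = torch.tensor([0.0, 0.5, 1.0])
x.sub_(0.5).div_(0.5)
print(x)  # tensor([-1., 0., 1.])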
...
...
lightx2v/models/runners/wan/wan_matrix_game2_runner.py
100644 → 100755
...
...
@@ -13,6 +13,7 @@ from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE


class VAEWrapper:
...
...
@@ -90,8 +91,8 @@ def get_current_action(mode="universal"):
                flag = 1
            except Exception as e:
                pass
-       mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).cuda()
-       keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
+       mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).to(AI_DEVICE)
+       keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
    elif mode == "gta_drive":
        print()
        print("-" * 30)
...
...
@@ -118,8 +119,8 @@ def get_current_action(mode="universal"):
                flag = 1
            except Exception as e:
                pass
-       mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).cuda()
-       keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).cuda()
+       mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).to(AI_DEVICE)
+       keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).to(AI_DEVICE)
    elif mode == "templerun":
        print()
        print("-" * 30)
...
...
@@ -142,7 +143,7 @@ def get_current_action(mode="universal"):
                flag = 1
            except Exception as e:
                pass
-       keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
+       keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
    if mode != "templerun":
        return {"mouse": mouse_cond, "keyboard": keyboard_cond}
...
...
lightx2v/models/runners/wan/wan_runner.py
...
...
@@ -164,7 +164,7 @@ class WanRunner(DefaultRunner):
        if vae_offload:
            vae_device = torch.device("cpu")
        else:
-           vae_device = torch.device(self.init_device)
+           vae_device = torch.device(AI_DEVICE)
        vae_config = {
            "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
...
...
@@ -178,7 +178,7 @@ class WanRunner(DefaultRunner):
        }
        if self.config.get("use_tae", False):
            tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
-           vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
+           vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to(AI_DEVICE)
        else:
            vae_decoder = self.vae_cls(**vae_config)
        return vae_decoder
...
...
lightx2v/models/schedulers/hunyuan_video/posemb_layers.py
100644 → 100755
...
...
@@ -2,6 +2,8 @@ from typing import List, Tuple, Union
import torch

+from lightx2v_platform.base.global_var import AI_DEVICE
+

def _to_tuple(x, dim=2):
    if isinstance(x, int):
...
...