xuwx1 / LightX2V · Commits

Commit 7a111e37
[Fix] Fix moe offload bug (#330)
Authored Sep 24, 2025 by gushiqiao, committed by GitHub on Sep 24, 2025
Parent: c6be06a6
Showing 7 changed files with 55 additions and 49 deletions (+55 −49).
configs/wan22/wan_distill_moe_flf2v.json       +1 −1
configs/wan22/wan_distill_moe_flf2v_fp8.json   +1 −1
configs/wan22/wan_distill_moe_flf2v_int8.json  +1 −1
configs/wan22/wan_moe_i2v_4090.json            +4 −2
configs/wan22/wan_moe_i2v_distill.json         +4 −11
configs/wan22/wan_moe_i2v_distill_quant.json   +3 −3
lightx2v/common/ops/mm/mm_weight.py            +41 −30
configs/wan22/wan_distill_moe_flf2v.json
@@ -12,7 +12,7 @@
     "sample_shift": 16,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
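The only change in this file (and in the fp8/int8 variants below) is offload_granularity: "model" becomes "block", i.e. weights are swapped between CPU and GPU one transformer block at a time instead of for the whole model at once, so peak VRAM is roughly one block's weights plus activations. A minimal sketch of how such a flag is typically dispatched; the _Block class and its method names are illustrative, not LightX2V's actual API:

    import json

    class _Block:
        """Stand-in for a transformer block with offloadable weights."""
        def to_cuda(self): print("block -> GPU")
        def to_cpu(self): print("block -> CPU")
        def forward(self): print("block forward")

    def run_offloaded(blocks, cfg):
        if cfg.get("offload_granularity") == "block":
            # Per-block: only one block is GPU-resident at a time.
            for b in blocks:
                b.to_cuda()
                b.forward()
                b.to_cpu()
        else:
            # "model": move everything over once per denoising pass.
            for b in blocks:
                b.to_cuda()
            for b in blocks:
                b.forward()
            for b in blocks:
                b.to_cpu()

    cfg = json.loads('{"cpu_offload": true, "offload_granularity": "block"}')
    run_offloaded([_Block(), _Block()], cfg)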
configs/wan22/wan_distill_moe_flf2v_fp8.json
@@ -12,7 +12,7 @@
     "sample_shift": 16,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
    "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
configs/wan22/wan_distill_moe_flf2v_int8.json
@@ -12,7 +12,7 @@
     "sample_shift": 16,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
configs/wan22/wan_moe_i2v_4090.json
@@ -10,9 +10,11 @@
     "seed": 42,
     "sample_guide_scale": [3.5, 3.5],
     "sample_shift": 5.0,
-    "enable_cfg": true,
+    "enable_cfg": false,
     "cpu_offload": true,
     "offload_granularity": "phase",
     "boundary": 0.900,
-    "use_image_encoder": false
+    "use_image_encoder": false,
+    "boundary_step_index": 2,
+    "denoising_step_list": [1000, 750, 500, 250]
 }
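Beyond disabling CFG for the distilled sampler, this config gains the boundary_step_index and denoising_step_list keys consumed by the Wan2.2 MoE pipeline, which hands the early, high-noise timesteps to one expert and the later timesteps to the other. A hedged sketch of that per-step selection; the expert names here are assumptions for illustration:

    denoising_step_list = [1000, 750, 500, 250]  # the 4 distilled steps
    boundary_step_index = 2                      # expert switch at step 2

    for i, t in enumerate(denoising_step_list):
        # Early (noisier) steps use one expert, later steps the other.
        expert = "high_noise_model" if i < boundary_step_index else "low_noise_model"
        print(f"step {i}: timestep {t} -> {expert}")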
configs/wan22/wan_moe_i2v_distill.json
@@ -2,8 +2,8 @@
     "infer_steps": 4,
     "target_video_length": 81,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 720,
+    "target_width": 1280,
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
@@ -12,17 +12,10 @@
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
-    "lora_configs": [
-        {
-            "name": "low_noise_model",
-            "path": "Wan2.1-I2V-14B-480P/loras/Wan21_I2V_14B_lightx2v_cfg_step_distill_lora_rank64.safetensors",
-            "strength": 1.0
-        }
-    ]
+    "denoising_step_list": [1000, 750, 500, 250]
 }
configs/wan22/wan_moe_i2v_distill_quant.json
@@ -2,8 +2,8 @@
     "infer_steps": 4,
     "target_video_length": 81,
     "text_len": 512,
-    "target_height": 480,
-    "target_width": 832,
+    "target_height": 720,
+    "target_width": 1280,
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
@@ -12,7 +12,7 @@
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
-    "offload_granularity": "model",
+    "offload_granularity": "block",
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "use_image_encoder": false,
lightx2v/common/ops/mm/mm_weight.py
 from abc import ABCMeta, abstractmethod
 import torch
-import torch.distributed as dist
 from loguru import logger
 from lightx2v.utils.envs import *
@@ -64,18 +63,25 @@ class MMWeightTemplate(metaclass=ABCMeta):
         self.config = config

     def to_cuda(self, non_blocking=False):
-        self.weight = self.weight.cuda(non_blocking=non_blocking)
-        if hasattr(self, "weight_scale"):
-            self.weight_scale = self.weight_scale.cuda(non_blocking=non_blocking)
-        if hasattr(self, "bias") and self.bias is not None:
-            self.bias = self.bias.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        if hasattr(self, "pin_weight_scale"):
+            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
+        if hasattr(self, "pin_bias") and self.pin_bias is not None:
+            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
-        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
-        if hasattr(self, "weight_scale"):
-            self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
-        if hasattr(self, "bias") and self.bias is not None:
-            self.bias = self.bias.to("cpu", non_blocking=non_blocking)
+        if hasattr(self, "pin_weight"):
+            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
+            if hasattr(self, "weight_scale_name"):
+                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
+            if self.bias is not None:
+                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
+        else:
+            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+            if hasattr(self, "weight_scale"):
+                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
+            if hasattr(self, "bias") and self.bias is not None:
+                self.bias = self.bias.to("cpu", non_blocking=non_blocking)

 @MM_WEIGHT_REGISTER("Default")
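This is the core of the offload fix: the page-locked host buffers created at load time now live under dedicated pin_* names, to_cuda stages from them, and to_cpu copies GPU tensors back into the same buffers with copy_ instead of allocating fresh pageable memory via .to("cpu") on every swap, which is what makes the non_blocking transfers both safe and effective. A minimal, self-contained sketch of that round trip; the class is illustrative and only the buffer discipline mirrors the diff:

    import torch

    class OffloadedWeight:
        def __init__(self, w):
            # One persistent page-locked host buffer, allocated at load time.
            self.pin_weight = torch.empty(w.shape, dtype=w.dtype, pin_memory=True)
            self.pin_weight.copy_(w)
            self.weight = None

        def to_cuda(self, non_blocking=False):
            # Pinned -> GPU; with non_blocking=True this can overlap compute.
            self.weight = self.pin_weight.cuda(non_blocking=non_blocking)

        def to_cpu(self, non_blocking=False):
            # GPU -> the same pinned buffer; no new host allocation per swap.
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking)

    if torch.cuda.is_available():
        ow = OffloadedWeight(torch.randn(4, 4))
        ow.to_cuda(non_blocking=True)
        ow.to_cpu(non_blocking=True)
        torch.cuda.synchronize()  # drain async copies before reusing the buffer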
@@ -92,16 +98,16 @@ class MMWeight(MMWeightTemplate):
         elif device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].t().shape
             weight_dtype = weight_dict[self.weight_name].dtype
-            self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-            self.weight.copy_(weight_dict[self.weight_name].t())
+            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+            self.pin_weight.copy_(weight_dict[self.weight_name].t())
             if self.bias_name is not None:
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
-                self.bias = None
+                self.pin_bias = None
             del weight_dict[self.weight_name]
         else:
             raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
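Note the pin_memory=True in these torch.empty calls: it allocates page-locked host memory, and only copies involving pinned buffers can actually run asynchronously when non_blocking=True is passed later in to_cuda/to_cpu. A quick check (pinned allocation needs the CUDA runtime):

    import torch

    if torch.cuda.is_available():
        t = torch.empty(4, 4, pin_memory=True)
        print(t.is_pinned())  # True: eligible for async host<->device transfers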
@@ -176,10 +182,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if not self.lazy_load:
             self.load_func(weight_dict)
             if self.weight_need_transpose:
-                self.weight = self.weight.t()
+                if hasattr(self, "weight"):
+                    self.weight = self.weight.t()
+                elif hasattr(self, "pin_weight"):
+                    self.pin_weight = self.pin_weight.t()

     def clear(self):
-        attrs = ["weight", "weight_scale", "bias", "pinned_weight", "pinned_weight_scale", "pinned_bias"]
+        attrs = ["weight", "weight_scale", "bias", "pin_weight", "pin_weight_scale", "pin_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
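The clear() change looks cosmetic but is plausibly part of the leak being fixed: the old list probed for pinned_weight / pinned_weight_scale / pinned_bias, names that nothing ever set, so hasattr was always False and the page-locked buffers (the pin_* attributes) survived every clear(). A toy reproduction of that mismatch:

    class W:  # stand-in object holding a pinned buffer
        pass

    w = W()
    w.pin_weight = bytearray(16)      # pretend pinned buffer

    for attr in ["pinned_weight"]:    # old, mismatched name: never matches
        if hasattr(w, attr):
            delattr(w, attr)
    print(hasattr(w, "pin_weight"))   # True -- buffer was never released

    for attr in ["pin_weight"]:       # corrected name from the diff
        if hasattr(w, attr):
            delattr(w, attr)
    print(hasattr(w, "pin_weight"))   # False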
@@ -198,15 +207,14 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         elif device.type == "cpu":
             weight_shape = weight_dict[self.weight_name].shape
             weight_dtype = weight_dict[self.weight_name].dtype
-            self.weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-            self.weight.copy_(weight_dict[self.weight_name])
+            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+            self.pin_weight.copy_(weight_dict[self.weight_name])
             weight_scale_shape = weight_dict[self.weight_scale_name].shape
             weight_scale_dtype = torch.float
-            self.weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
-            self.weight_scale.copy_(weight_dict[self.weight_scale_name])
-            if dist.is_initialized():
-                del weight_dict[self.weight_name]
+            self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
+            self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
+            del weight_dict[self.weight_name]
         else:
             raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
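The quantized CPU load path previously freed the source tensor from weight_dict only when torch.distributed was initialized (hence the dropped dist import); the fixed code frees it unconditionally once the pinned copy exists, so single-GPU runs no longer keep two host copies alive. A toy illustration of this stage-then-free pattern, with illustrative names:

    import torch

    def stage_to_pinned(weight_dict, name):
        # Copy the checkpoint tensor into a page-locked buffer, then drop the
        # original so only one host copy stays alive (illustrative helper).
        src = weight_dict[name]
        pinned = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
        pinned.copy_(src)
        del weight_dict[name]
        return pinned

    if torch.cuda.is_available():  # pinned allocation needs the CUDA runtime
        wd = {"proj.weight": torch.randn(8, 8)}
        pin_weight = stage_to_pinned(wd, "proj.weight")
        print(pin_weight.is_pinned(), "proj.weight" in wd)  # True False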
@@ -227,12 +235,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             elif device.type == "cpu":
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
+            self.pin_bias = None

     def load_int8_perchannel_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -251,12 +260,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             elif device.type == "cpu":
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
+            self.pin_bias = None

     def load_fp8_perblock128_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
@@ -272,12 +282,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             elif device.type == "cpu":
                 bias_shape = weight_dict[self.bias_name].shape
                 bias_dtype = weight_dict[self.bias_name].dtype
-                self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                self.bias.copy_(weight_dict[self.bias_name])
+                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
+                self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
                 raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
         else:
             self.bias = None
+            self.pin_bias = None

     def per_block_cast_to_fp8(self, x):
         assert x.dim() == 2