Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
492501d7
Commit
492501d7
authored
Sep 23, 2025
by
gushiqiao
Committed by
GitHub
Sep 23, 2025
Browse files
[Fix] Fix move moe model to cpu bug (#328)
parent
409e5cec
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
8 additions
and
4 deletions
+8
-4
lightx2v/models/networks/wan/audio_model.py
lightx2v/models/networks/wan/audio_model.py
+2
-2
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+2
-2
lightx2v/models/runners/wan/wan_distill_runner.py
lightx2v/models/runners/wan/wan_distill_runner.py
+2
-0
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+2
-0
No files found.
lightx2v/models/networks/wan/audio_model.py
View file @
492501d7
...
@@ -89,7 +89,7 @@ class WanAudioModel(WanModel):
...
@@ -89,7 +89,7 @@ class WanAudioModel(WanModel):
self
.
enable_compile_mode
(
"_infer_cond_uncond"
)
self
.
enable_compile_mode
(
"_infer_cond_uncond"
)
if
self
.
cpu_offload
:
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
0
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
0
and
"wan2.2_moe"
not
in
self
.
config
.
model_cls
:
self
.
to_cuda
()
self
.
to_cuda
()
elif
self
.
offload_granularity
!=
"model"
:
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
pre_weight
.
to_cuda
()
...
@@ -99,7 +99,7 @@ class WanAudioModel(WanModel):
...
@@ -99,7 +99,7 @@ class WanAudioModel(WanModel):
self
.
start_compile
(
shape
)
self
.
start_compile
(
shape
)
if
self
.
cpu_offload
:
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
self
.
scheduler
.
infer_steps
-
1
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
self
.
scheduler
.
infer_steps
-
1
and
"wan2.2_moe"
not
in
self
.
config
.
model_cls
:
self
.
to_cpu
()
self
.
to_cpu
()
elif
self
.
offload_granularity
!=
"model"
:
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
pre_weight
.
to_cpu
()
...
...
lightx2v/models/networks/wan/model.py
View file @
492501d7
...
@@ -344,7 +344,7 @@ class WanModel(CompiledMethodsMixin):
...
@@ -344,7 +344,7 @@ class WanModel(CompiledMethodsMixin):
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
infer
(
self
,
inputs
):
def
infer
(
self
,
inputs
):
if
self
.
cpu_offload
:
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
0
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
0
and
"wan2.2_moe"
not
in
self
.
config
.
model_cls
:
self
.
to_cuda
()
self
.
to_cuda
()
elif
self
.
offload_granularity
!=
"model"
:
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cuda
()
self
.
pre_weight
.
to_cuda
()
...
@@ -377,7 +377,7 @@ class WanModel(CompiledMethodsMixin):
...
@@ -377,7 +377,7 @@ class WanModel(CompiledMethodsMixin):
self
.
scheduler
.
noise_pred
=
self
.
_infer_cond_uncond
(
inputs
,
infer_condition
=
True
)
self
.
scheduler
.
noise_pred
=
self
.
_infer_cond_uncond
(
inputs
,
infer_condition
=
True
)
if
self
.
cpu_offload
:
if
self
.
cpu_offload
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
self
.
scheduler
.
infer_steps
-
1
:
if
self
.
offload_granularity
==
"model"
and
self
.
scheduler
.
step_index
==
self
.
scheduler
.
infer_steps
-
1
and
"wan2.2_moe"
not
in
self
.
config
.
model_cls
:
self
.
to_cpu
()
self
.
to_cpu
()
elif
self
.
offload_granularity
!=
"model"
:
elif
self
.
offload_granularity
!=
"model"
:
self
.
pre_weight
.
to_cpu
()
self
.
pre_weight
.
to_cpu
()
...
...
lightx2v/models/runners/wan/wan_distill_runner.py
View file @
492501d7
...
@@ -7,6 +7,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
...
@@ -7,6 +7,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from
lightx2v.models.networks.wan.model
import
WanModel
from
lightx2v.models.networks.wan.model
import
WanModel
from
lightx2v.models.runners.wan.wan_runner
import
MultiModelStruct
,
WanRunner
from
lightx2v.models.runners.wan.wan_runner
import
MultiModelStruct
,
WanRunner
from
lightx2v.models.schedulers.wan.step_distill.scheduler
import
Wan22StepDistillScheduler
,
WanStepDistillScheduler
from
lightx2v.models.schedulers.wan.step_distill.scheduler
import
Wan22StepDistillScheduler
,
WanStepDistillScheduler
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
...
@@ -49,6 +50,7 @@ class MultiDistillModelStruct(MultiModelStruct):
...
@@ -49,6 +50,7 @@ class MultiDistillModelStruct(MultiModelStruct):
self
.
cur_model_index
=
-
1
self
.
cur_model_index
=
-
1
logger
.
info
(
f
"boundary step index:
{
self
.
boundary_step_index
}
"
)
logger
.
info
(
f
"boundary step index:
{
self
.
boundary_step_index
}
"
)
@
ProfilingContext4DebugL2
(
"Swtich models in infer_main costs"
)
def
get_current_model_index
(
self
):
def
get_current_model_index
(
self
):
if
self
.
scheduler
.
step_index
<
self
.
boundary_step_index
:
if
self
.
scheduler
.
step_index
<
self
.
boundary_step_index
:
logger
.
info
(
f
"using - HIGH - noise model at step_index
{
self
.
scheduler
.
step_index
+
1
}
"
)
logger
.
info
(
f
"using - HIGH - noise model at step_index
{
self
.
scheduler
.
step_index
+
1
}
"
)
...
...
lightx2v/models/runners/wan/wan_runner.py
View file @
492501d7
...
@@ -25,6 +25,7 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
...
@@ -25,6 +25,7 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from
lightx2v.models.video_encoders.hf.wan.vae_2_2
import
Wan2_2_VAE
from
lightx2v.models.video_encoders.hf.wan.vae_2_2
import
Wan2_2_VAE
from
lightx2v.models.video_encoders.hf.wan.vae_tiny
import
Wan2_2_VAE_tiny
,
WanVAE_tiny
from
lightx2v.models.video_encoders.hf.wan.vae_tiny
import
Wan2_2_VAE_tiny
,
WanVAE_tiny
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.profiler
import
*
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.utils
import
*
from
lightx2v.utils.utils
import
*
from
lightx2v.utils.utils
import
best_output_size
,
cache_video
from
lightx2v.utils.utils
import
best_output_size
,
cache_video
...
@@ -395,6 +396,7 @@ class MultiModelStruct:
...
@@ -395,6 +396,7 @@ class MultiModelStruct:
self
.
get_current_model_index
()
self
.
get_current_model_index
()
self
.
model
[
self
.
cur_model_index
].
infer
(
inputs
)
self
.
model
[
self
.
cur_model_index
].
infer
(
inputs
)
@
ProfilingContext4DebugL2
(
"Swtich models in infer_main costs"
)
def
get_current_model_index
(
self
):
def
get_current_model_index
(
self
):
if
self
.
scheduler
.
timesteps
[
self
.
scheduler
.
step_index
]
>=
self
.
boundary_timestep
:
if
self
.
scheduler
.
timesteps
[
self
.
scheduler
.
step_index
]
>=
self
.
boundary_timestep
:
logger
.
info
(
f
"using - HIGH - noise model at step_index
{
self
.
scheduler
.
step_index
+
1
}
"
)
logger
.
info
(
f
"using - HIGH - noise model at step_index
{
self
.
scheduler
.
step_index
+
1
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment