Commit 492501d7 authored by gushiqiao, committed by GitHub
Browse files

[Fix] Fix move moe model to cpu bug (#328)

parent 409e5cec
...@@ -89,7 +89,7 @@ class WanAudioModel(WanModel): ...@@ -89,7 +89,7 @@ class WanAudioModel(WanModel):
self.enable_compile_mode("_infer_cond_uncond") self.enable_compile_mode("_infer_cond_uncond")
if self.cpu_offload: if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0: if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
self.to_cuda() self.to_cuda()
elif self.offload_granularity != "model": elif self.offload_granularity != "model":
self.pre_weight.to_cuda() self.pre_weight.to_cuda()
...@@ -99,7 +99,7 @@ class WanAudioModel(WanModel): ...@@ -99,7 +99,7 @@ class WanAudioModel(WanModel):
self.start_compile(shape) self.start_compile(shape)
if self.cpu_offload: if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1: if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
self.to_cpu() self.to_cpu()
elif self.offload_granularity != "model": elif self.offload_granularity != "model":
self.pre_weight.to_cpu() self.pre_weight.to_cpu()
......
...@@ -344,7 +344,7 @@ class WanModel(CompiledMethodsMixin): ...@@ -344,7 +344,7 @@ class WanModel(CompiledMethodsMixin):
@torch.no_grad() @torch.no_grad()
def infer(self, inputs): def infer(self, inputs):
if self.cpu_offload: if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0: if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
self.to_cuda() self.to_cuda()
elif self.offload_granularity != "model": elif self.offload_granularity != "model":
self.pre_weight.to_cuda() self.pre_weight.to_cuda()
...@@ -377,7 +377,7 @@ class WanModel(CompiledMethodsMixin): ...@@ -377,7 +377,7 @@ class WanModel(CompiledMethodsMixin):
self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True) self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
if self.cpu_offload: if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1: if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
self.to_cpu() self.to_cpu()
elif self.offload_granularity != "model": elif self.offload_granularity != "model":
self.pre_weight.to_cpu() self.pre_weight.to_cpu()
......
...@@ -7,6 +7,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper ...@@ -7,6 +7,7 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler from lightx2v.models.schedulers.wan.step_distill.scheduler import Wan22StepDistillScheduler, WanStepDistillScheduler
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
...@@ -49,6 +50,7 @@ class MultiDistillModelStruct(MultiModelStruct): ...@@ -49,6 +50,7 @@ class MultiDistillModelStruct(MultiModelStruct):
self.cur_model_index = -1 self.cur_model_index = -1
logger.info(f"boundary step index: {self.boundary_step_index}") logger.info(f"boundary step index: {self.boundary_step_index}")
@ProfilingContext4DebugL2("Swtich models in infer_main costs")
def get_current_model_index(self): def get_current_model_index(self):
if self.scheduler.step_index < self.boundary_step_index: if self.scheduler.step_index < self.boundary_step_index:
logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}") logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
......
...@@ -25,6 +25,7 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE ...@@ -25,6 +25,7 @@ from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import * from lightx2v.utils.utils import *
from lightx2v.utils.utils import best_output_size, cache_video from lightx2v.utils.utils import best_output_size, cache_video
...@@ -395,6 +396,7 @@ class MultiModelStruct: ...@@ -395,6 +396,7 @@ class MultiModelStruct:
self.get_current_model_index() self.get_current_model_index()
self.model[self.cur_model_index].infer(inputs) self.model[self.cur_model_index].infer(inputs)
@ProfilingContext4DebugL2("Swtich models in infer_main costs")
def get_current_model_index(self): def get_current_model_index(self):
if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep: if self.scheduler.timesteps[self.scheduler.step_index] >= self.boundary_timestep:
logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}") logger.info(f"using - HIGH - noise model at step_index {self.scheduler.step_index + 1}")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment