Commit 4f7c54d8 authored by helloyongyang

fix wan model bug

parent 984cd6c9
@@ -187,10 +187,10 @@ class WanModel:
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
-        if self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
-            self.infer = self.infer_with_cfg_parallel
+        if self.config["enable_cfg"] and self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
+            self.infer_func = self.infer_with_cfg_parallel
         else:
-            self.infer = self.infer_wo_cfg_parallel
+            self.infer_func = self.infer_wo_cfg_parallel
 
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -208,6 +208,10 @@ class WanModel:
         self.post_weight.to_cuda()
         self.transformer_weights.to_cuda()
 
+    @torch.no_grad()
+    def infer(self, inputs):
+        return self.infer_func(inputs)
+
     @torch.no_grad()
     def infer_wo_cfg_parallel(self, inputs):
         if self.cpu_offload:
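A minimal sketch (not part of the commit) of the dispatch pattern the fix introduces: instead of overwriting `self.infer` with a bound method in the constructor, the chosen implementation is stored under `self.infer_func` and a single `@torch.no_grad()`-decorated `infer` method forwards to it, so the no-grad context applies to whichever branch runs and the CFG-parallel path can be gated by `enable_cfg`. `DummyModel`, its constructor arguments, and the string return values below are illustrative assumptions, not LightX2V/WanModel code.

```python
import torch


class DummyModel:
    """Hypothetical stand-in showing the infer_func dispatch pattern."""

    def __init__(self, enable_cfg, cfg_p_size):
        # Keep the selected implementation under a separate attribute rather
        # than replacing self.infer, so the decorated wrapper below stays in
        # charge of dispatch.
        if enable_cfg and cfg_p_size > 1:
            self.infer_func = self.infer_with_cfg_parallel
        else:
            self.infer_func = self.infer_wo_cfg_parallel

    @torch.no_grad()
    def infer(self, inputs):
        # Single public entry point; no_grad covers whichever branch runs.
        return self.infer_func(inputs)

    def infer_with_cfg_parallel(self, inputs):
        return f"cfg-parallel path, inputs={inputs}"

    def infer_wo_cfg_parallel(self, inputs):
        return f"single path, inputs={inputs}"


print(DummyModel(enable_cfg=True, cfg_p_size=2).infer("x"))   # cfg-parallel path
print(DummyModel(enable_cfg=False, cfg_p_size=2).infer("x"))  # single path
```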