"src/include/ConstantTensorDescriptor.hpp" did not exist on "2a48812edb1a7c3e280159637fa89b7a0bbfb86b"
Commit 4f7c54d8 authored by helloyongyang's avatar helloyongyang
Browse files

fix wan model bug

parent 984cd6c9
...@@ -187,10 +187,10 @@ class WanModel: ...@@ -187,10 +187,10 @@ class WanModel:
self.pre_infer = self.pre_infer_class(self.config) self.pre_infer = self.pre_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config) self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config) self.transformer_infer = self.transformer_infer_class(self.config)
if self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1: if self.config["enable_cfg"] and self.config.parallel and self.config.parallel.get("cfg_p_size", False) and self.config.parallel.cfg_p_size > 1:
self.infer = self.infer_with_cfg_parallel self.infer_func = self.infer_with_cfg_parallel
else: else:
self.infer = self.infer_wo_cfg_parallel self.infer_func = self.infer_wo_cfg_parallel
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
...@@ -208,6 +208,10 @@ class WanModel: ...@@ -208,6 +208,10 @@ class WanModel:
self.post_weight.to_cuda() self.post_weight.to_cuda()
self.transformer_weights.to_cuda() self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, inputs):
return self.infer_func(inputs)
@torch.no_grad() @torch.no_grad()
def infer_wo_cfg_parallel(self, inputs): def infer_wo_cfg_parallel(self, inputs):
if self.cpu_offload: if self.cpu_offload:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment