Commit 53787c69 authored by helloyongyang's avatar helloyongyang
Browse files

fix parallel compile

parent 067a2b61
...@@ -14,7 +14,6 @@ class WanPostInfer: ...@@ -14,7 +14,6 @@ class WanPostInfer:
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, x, pre_infer_out): def infer(self, x, pre_infer_out):
x = self.unpatchify(x, pre_infer_out.grid_sizes) x = self.unpatchify(x, pre_infer_out.grid_sizes)
......
...@@ -32,7 +32,6 @@ class WanPreInfer: ...@@ -32,7 +32,6 @@ class WanPreInfer:
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, inputs, kv_start=0, kv_end=0): def infer(self, weights, inputs, kv_start=0, kv_end=0):
x = self.scheduler.latents x = self.scheduler.latents
......
...@@ -97,7 +97,6 @@ class WanTransformerInfer(BaseTransformerInfer): ...@@ -97,7 +97,6 @@ class WanTransformerInfer(BaseTransformerInfer):
freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs) freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
return freqs_i return freqs_i
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, pre_infer_out): def infer(self, weights, pre_infer_out):
x = self.infer_main_blocks(weights, pre_infer_out) x = self.infer_main_blocks(weights, pre_infer_out)
return self.infer_post_blocks(weights, x, pre_infer_out.embed) return self.infer_post_blocks(weights, x, pre_infer_out.embed)
......
...@@ -354,6 +354,7 @@ class WanModel: ...@@ -354,6 +354,7 @@ class WanModel:
self.pre_weight.to_cpu() self.pre_weight.to_cpu()
self.transformer_weights.post_weights_to_cpu() self.transformer_weights.post_weights_to_cpu()
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
@torch.no_grad() @torch.no_grad()
def _infer_cond_uncond(self, inputs, infer_condition=True): def _infer_cond_uncond(self, inputs, infer_condition=True):
self.scheduler.infer_condition = infer_condition self.scheduler.infer_condition = infer_condition
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment