Commit f185da14 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Fix] fix vace bug (#242)

* [Fix] fix bug

* Fix bug

* Update run_wan_vace.sh
parent 3488b187
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any, Dict from typing import Any, Dict
import torch import torch
...@@ -6,7 +6,6 @@ import torch ...@@ -6,7 +6,6 @@ import torch
@dataclass @dataclass
class WanPreInferModuleOutput: class WanPreInferModuleOutput:
# wan base model
embed: torch.Tensor embed: torch.Tensor
grid_sizes: torch.Tensor grid_sizes: torch.Tensor
x: torch.Tensor x: torch.Tensor
...@@ -14,6 +13,4 @@ class WanPreInferModuleOutput: ...@@ -14,6 +13,4 @@ class WanPreInferModuleOutput:
seq_lens: torch.Tensor seq_lens: torch.Tensor
freqs: torch.Tensor freqs: torch.Tensor
context: torch.Tensor context: torch.Tensor
adapter_output: Dict[str, Any] = field(default_factory=dict)
# wan adapter model
adapter_output: Dict[str, Any] = None
...@@ -153,7 +153,7 @@ class WanVaceRunner(WanRunner): ...@@ -153,7 +153,7 @@ class WanVaceRunner(WanRunner):
self.config.target_shape = target_shape self.config.target_shape = target_shape
@ProfilingContext("Run VAE Decoder") @ProfilingContext("Run VAE Decoder")
def _run_vae_decoder_local(self, latents, generator): def run_vae_decoder(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder() self.vae_decoder = self.load_vae_decoder()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment