Commit f185da14 authored by gushiqiao, committed by GitHub

[Fix] fix vace bug (#242)

* [Fix] fix bug

* Fix bug

* Update run_wan_vace.sh
parent 3488b187
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Dict
 import torch
@@ -6,7 +6,6 @@ import torch
 @dataclass
 class WanPreInferModuleOutput:
     # wan base model
     embed: torch.Tensor
     grid_sizes: torch.Tensor
     x: torch.Tensor
@@ -14,6 +13,4 @@ class WanPreInferModuleOutput:
     seq_lens: torch.Tensor
     freqs: torch.Tensor
     context: torch.Tensor
     # wan adapter model
-    adapter_output: Dict[str, Any] = None
+    adapter_output: Dict[str, Any] = field(default_factory=dict)
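
Why this change is needed: a dataclass field typed Dict[str, Any] cannot default to a mutable literal, and defaulting to None forces every consumer to null-check before indexing into it. A minimal sketch of both failure modes, assuming nothing beyond the dataclass shown above (the key "vace_hints" is illustrative, not taken from this repo):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class Output:
    # `adapter_output: Dict[str, Any] = {}` would raise
    # "ValueError: mutable default <class 'dict'> for field adapter_output",
    # and `= None` makes `output.adapter_output["k"]` a TypeError at runtime.
    # default_factory builds a fresh dict per instance instead.
    adapter_output: Dict[str, Any] = field(default_factory=dict)

out = Output()
out.adapter_output["vace_hints"] = []  # hypothetical key, for illustration only
print(out.adapter_output)              # {'vace_hints': []}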
@@ -153,7 +153,7 @@ class WanVaceRunner(WanRunner):
         self.config.target_shape = target_shape

     @ProfilingContext("Run VAE Decoder")
-    def _run_vae_decoder_local(self, latents, generator):
+    def run_vae_decoder(self, latents, generator):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.vae_decoder = self.load_vae_decoder()
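
One plausible reading of the rename: Python dispatches overrides by method name, so a subclass method under a different name never replaces the base implementation. A minimal sketch with hypothetical classes (not the actual WanRunner API):

class BaseRunner:
    def run_vae_decoder(self, latents):
        return "base decode"

class VaceRunnerBroken(BaseRunner):
    # Different name: pipeline code that calls run_vae_decoder()
    # still hits BaseRunner's version and skips the VACE-specific logic.
    def _run_vae_decoder_local(self, latents):
        return "vace decode"

class VaceRunnerFixed(BaseRunner):
    # Same name: the override is actually picked up.
    def run_vae_decoder(self, latents):
        return "vace decode"

print(VaceRunnerBroken().run_vae_decoder(None))  # base decode  (the bug)
print(VaceRunnerFixed().run_vae_decoder(None))   # vace decode  (the fix)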