Commit 87343386 authored by sandy, committed by GitHub
Browse files

Fix Wan2.2 VAE encode API (#244)

* bugfix: adapt to the 5B DiT model; derive attention_head_dim from config["dim"]

* [Fix] Wan2.2 VAE encode refactor: drop the args parameter and use self.cpu_offload
parent f185da14
......@@ -512,7 +512,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def load_audio_adapter(self):
audio_adapter = AudioAdapter(
- attention_head_dim=5120 // self.config["num_heads"],
+ attention_head_dim=self.config["dim"] // self.config["num_heads"],
num_attention_heads=self.config["num_heads"],
base_num_layers=self.config["num_layers"],
interval=1,
......
......@@ -985,11 +985,11 @@ class Wan2_2_VAE:
self.inv_std = self.inv_std.cuda()
self.scale = [self.mean, self.inv_std]
- def encode(self, videos, args):
- if hasattr(args, "cpu_offload") and args.cpu_offload:
+ def encode(self, videos):
+ if self.cpu_offload:
self.to_cuda()
out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
- if hasattr(args, "cpu_offload") and args.cpu_offload:
+ if self.cpu_offload:
self.to_cpu()
return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment