Commit 87343386 authored by sandy's avatar sandy Committed by GitHub
Browse files

Fix/wan2 2 vae encode api (#244)

* bugfix:adapt to  5B dit model, derive attention_head_dim from config[dim]

* [Fix] Wan2.2 Vae Encode refactor: drop args parameter and use self.cpu_offload
parent f185da14
...@@ -512,7 +512,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -512,7 +512,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def load_audio_adapter(self): def load_audio_adapter(self):
audio_adapter = AudioAdapter( audio_adapter = AudioAdapter(
attention_head_dim=5120 // self.config["num_heads"], attention_head_dim=self.config["dim"] // self.config["num_heads"],
num_attention_heads=self.config["num_heads"], num_attention_heads=self.config["num_heads"],
base_num_layers=self.config["num_layers"], base_num_layers=self.config["num_layers"],
interval=1, interval=1,
......
...@@ -985,11 +985,11 @@ class Wan2_2_VAE: ...@@ -985,11 +985,11 @@ class Wan2_2_VAE:
self.inv_std = self.inv_std.cuda() self.inv_std = self.inv_std.cuda()
self.scale = [self.mean, self.inv_std] self.scale = [self.mean, self.inv_std]
def encode(self, videos, args): def encode(self, videos):
if hasattr(args, "cpu_offload") and args.cpu_offload: if self.cpu_offload:
self.to_cuda() self.to_cuda()
out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0) out = self.model.encode(videos.unsqueeze(0), self.scale).float().squeeze(0)
if hasattr(args, "cpu_offload") and args.cpu_offload: if self.cpu_offload:
self.to_cpu() self.to_cpu()
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment