"vscode:/vscode.git/clone" did not exist on "2c93b3057dbfd46600674820888570c7c45857cc"
Commit c550409c authored by wangshankun

update ref timestep expand

parent 0b755a97
@@ -104,17 +104,34 @@ class WanAudioPreInfer(WanPreInfer):
         y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
         # y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
         y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
+        ref_seq_lens = torch.tensor([u.size(0) for u in y], dtype=torch.long)
         x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
         x = torch.stack(x, dim=0)
+        seq_len = x[0].size(0)
+        if self.config.model_cls == "wan2.2_audio":
+            bt = t.size(0)
+            ref_seq_len = ref_seq_lens[0].item()
+            t = torch.cat(
+                [
+                    t,
+                    torch.zeros(
+                        (1, ref_seq_len),
+                        dtype=t.dtype,
+                        device=t.device,
+                    ),
+                ],
+                dim=1,
+            )
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
+        # embed = weights.time_embedding_0.apply(embed)
         if self.sensitive_layer_dtype != self.infer_dtype:
             embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
         else:
             embed = weights.time_embedding_0.apply(embed)
         embed = torch.nn.functional.silu(embed)
         embed = weights.time_embedding_2.apply(embed)
         embed0 = torch.nn.functional.silu(embed)
         embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))
...
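For context, a minimal standalone sketch of the idea behind this change: the reference-image latents are patch-embedded, flattened, and concatenated onto x, so the timestep tensor t is padded with zeros over those ref_seq_len extra tokens (reference tokens are treated as already clean, timestep 0) before the 1D sinusoidal embedding is computed over the flattened vector. The sinusoidal_embedding_1d body and the concrete sizes below are illustrative assumptions, not the repository's code.

    import torch

    def sinusoidal_embedding_1d(dim, position):
        # Illustrative re-implementation: half-cosine / half-sine 1D embedding.
        half = dim // 2
        freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float32) / half)
        args = position.to(torch.float32).unsqueeze(-1) * freqs          # (N, half)
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)     # (N, dim)

    freq_dim = 256                                     # assumed embedding width
    t = torch.full((1, 30), 999.0)                     # per-token timesteps for 30 video tokens (assumed size)
    ref_seq_len = 6                                    # tokens contributed by the reference image (assumed size)
    t = torch.cat([t, torch.zeros((1, ref_seq_len))], dim=1)   # expand: reference tokens get timestep 0
    embed = sinusoidal_embedding_1d(freq_dim, t.flatten())     # (30 + 6, 256), matches the concatenated sequence
    print(embed.shape)                                 # torch.Size([36, 256])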