"testing/python/vscode:/vscode.git/clone" did not exist on "7248a810d97ca8ceb999cc0a9e2bf58adc68f263"
Unverified Commit ff263947 authored by Charchit Sharma's avatar Charchit Sharma Committed by GitHub
Browse files

Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)



* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers

- Store t_dim, h_dim, w_dim as instance variables in WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed __init__
- Use stored dimensions in forward() instead of recalculating with different formula
- Fixes inconsistency between init (using // 6) and forward (using // 3)
- Ensures split_sizes matches the dimensions used to create rotary embeddings

* quality fix

---------
Co-authored-by: default avatarCharchit Sharma <charchitsharma@A-267.local>
parent 66e6a021
......@@ -389,6 +389,10 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim
freqs_cos = []
freqs_sin = []
......@@ -412,11 +416,7 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
......
......@@ -362,6 +362,11 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
......@@ -387,11 +392,7 @@ class WanRotaryPosEmbed(nn.Module):
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment