Unverified commit ff263947, authored by Charchit Sharma, committed by GitHub
Browse files

Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers (#12594)



* Fix rotary positional embedding dimension mismatch in Wan and SkyReels V2 transformers

- Store t_dim, h_dim, w_dim as instance variables in WanRotaryPosEmbed and SkyReelsV2RotaryPosEmbed __init__
- Use stored dimensions in forward() instead of recalculating with different formula
- Fixes inconsistency between init (using // 6) and forward (using // 3)
- Ensures split_sizes matches the dimensions used to create rotary embeddings

* quality fix

---------
Co-authored-by: Charchit Sharma <charchitsharma@A-267.local>
parent 66e6a021
@@ -389,6 +389,10 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
         t_dim = attention_head_dim - h_dim - w_dim
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        self.t_dim = t_dim
+        self.h_dim = h_dim
+        self.w_dim = w_dim
         freqs_cos = []
         freqs_sin = []
@@ -412,11 +416,7 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
-        split_sizes = [
-            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
-            self.attention_head_dim // 3,
-            self.attention_head_dim // 3,
-        ]
+        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
......
@@ -362,6 +362,11 @@ class WanRotaryPosEmbed(nn.Module):
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
+        self.t_dim = t_dim
+        self.h_dim = h_dim
+        self.w_dim = w_dim
         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         freqs_cos = []
@@ -387,11 +392,7 @@ class WanRotaryPosEmbed(nn.Module):
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
-        split_sizes = [
-            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
-            self.attention_head_dim // 3,
-            self.attention_head_dim // 3,
-        ]
+        split_sizes = [self.t_dim, self.h_dim, self.w_dim]
         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment