Unverified Commit 7fc53b5d authored by Tolga Cangöz's avatar Tolga Cangöz Committed by GitHub
Browse files

Fix dimensionalities in `apply_rotary_emb` functions' comments (#11717)

Fix dimensionality in `apply_rotary_emb` functions' comments.
parent 0874dd04
...@@ -1199,11 +1199,11 @@ def apply_rotary_emb( ...@@ -1199,11 +1199,11 @@ def apply_rotary_emb(
if use_real_unbind_dim == -1: if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit # Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2: elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen, CogView4 and Cosmos # Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1) x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else: else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
......
...@@ -481,7 +481,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin ...@@ -481,7 +481,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
def apply_rotary_emb(x, freqs): def apply_rotary_emb(x, freqs):
cos, sin = freqs cos, sin = freqs
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment