Unverified Commit 4f495b06 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312)

fix notes and dtype
parent 40c13fe5
...@@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed( ...@@ -514,7 +514,7 @@ def get_1d_rotary_pos_embed(
linear_factor=1.0, linear_factor=1.0,
ntk_factor=1.0, ntk_factor=1.0,
repeat_interleave_real=True, repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux) freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
): ):
""" """
Precompute the frequency tensor for complex exponentials (cis) with given dimensions. Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
...@@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed( ...@@ -551,15 +551,18 @@ def get_1d_rotary_pos_embed(
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2] freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real: if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
elif use_real: elif use_real:
# stable audio
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
else: else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2] # lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis return freqs_cis
...@@ -590,11 +593,11 @@ def apply_rotary_emb( ...@@ -590,11 +593,11 @@ def apply_rotary_emb(
cos, sin = cos.to(x.device), sin.to(x.device) cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1: if use_real_unbind_dim == -1:
# Use for example in Lumina # Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2: elif use_real_unbind_dim == -2:
# Use for example in Stable Audio # Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1) x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else: else:
...@@ -604,6 +607,7 @@ def apply_rotary_emb( ...@@ -604,6 +607,7 @@ def apply_rotary_emb(
return out return out
else: else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2) freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment