Unverified Commit 61d96c3a authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

Refactor rotary embedding, part 3: keep the position computation off the CPU (#9307)

Change `get_1d_rotary_pos_embed` to accept `pos` as a torch tensor, so positions stay on the caller's device instead of round-tripping through NumPy on the CPU.
parent 4f495b06
......@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed(
assert dim % 2 == 0
if isinstance(pos, int):
pos = np.arange(pos)
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
freqs = freqs.to(pos.device)
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
......@@ -626,7 +629,7 @@ class FluxPosEmbed(nn.Module):
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.squeeze().float().cpu().numpy()
pos = ids.squeeze().float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment