"graphbolt/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "a2234d60752631d92f46fed9d8be1612c4acbbfd"
Unverified Commit dcf320f2 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

small update on rotary embedding (#9354)



* update

* fix

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 8ba90aa7
...@@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed( ...@@ -608,8 +608,11 @@ def get_1d_rotary_pos_embed(
pos = torch.from_numpy(pos) # type: ignore # [S] pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] freqs = (
freqs = freqs.to(pos.device) 1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real: if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox # flux, hunyuan-dit, cogvideox
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment