Unverified Commit 461efc57 authored by 王奇勋's avatar 王奇勋 Committed by GitHub
Browse files

[fix code annotation] Adjust the dimensions of the rotary positional embedding. (#8890)



* 2d rotary pos emb dim

* make style

---------
Co-authored-by: default avatarhaofanwang <haofanwang.ai@gmail.com>
parent 3b04cdc8
...@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False): ...@@ -319,12 +319,16 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0 assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h # use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4) emb_h = get_1d_rotary_pos_embed(
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4) embed_dim // 2, grid[0].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(
embed_dim // 2, grid[1].reshape(-1), use_real=use_real
) # (H*W, D/2) if use_real else (H*W, D/4)
if use_real: if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2) cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2) sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
return cos, sin return cos, sin
else: else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2) emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
...@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed( ...@@ -371,6 +375,8 @@ def get_1d_rotary_pos_embed(
Returns: Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
""" """
assert dim % 2 == 0
if isinstance(pos, int): if isinstance(pos, int):
pos = np.arange(pos) pos = np.arange(pos)
theta = theta * ntk_factor theta = theta * ntk_factor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment