Unverified Commit 74b5fed4 authored by Aishwarya Badlani's avatar Aishwarya Badlani Committed by GitHub
Browse files

Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432 (#12449)



* Fix MPS compatibility in get_1d_sincos_pos_embed_from_grid #12432

* Fix trailing whitespace in docstring

* Apply style fixes

---------
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 85eb5056
......@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
"""
This function generates 1D positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension `D`
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
`torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
......@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
# Auto-detect appropriate dtype if not specified
if dtype is None:
dtype = torch.float32 if pos.device.type == "mps" else torch.float64
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment