Unverified commit 43b752c3, authored by cjackal, committed by GitHub
Browse files

[Llama4] [multimodal] Fix misplaced dtype cast of `cos_sin_cache` in...


[Llama4] [multimodal] Fix misplaced dtype cast of `cos_sin_cache` in `Llama4VisionRotaryEmbedding` (#25889)
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
parent cfd302db
...@@ -59,7 +59,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding): ...@@ -59,7 +59,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
key: Optional[torch.Tensor] = None, key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None assert key is not None
self._match_cos_sin_cache_dtype(query) # self.cos_sin_cache here is complex tensor so we cannot cast into
# query's dtype directly with self._match_cos_sin_cache_dtype
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
query_ = torch.view_as_complex(query.float().reshape( query_ = torch.view_as_complex(query.float().reshape(
*query.shape[:-1], -1, 2)) *query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape( key_ = torch.view_as_complex(key.float().reshape(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment