"examples/vscode:/vscode.git/clone" did not exist on "58431f102cf39c3c8a569f32d71b2ea8caa461e1"
Unverified Commit 74de76c6 authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

Revise MRotaryEmbedding's forward (#11859)


Co-authored-by: default avatarluoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: default avatar羽癫 <yudian.zy@antgroup.com>
Co-authored-by: default avatarb8zhong <b8zhong@uwaterloo.ca>
parent 9c0b1eb5
...@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
@torch.compile(dynamic=True, backend=get_compiler_backend()) @torch.compile(dynamic=True, backend=get_compiler_backend())
def forward_native( def _forward_native(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
...@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with optional Triton kernel acceleration.
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
if positions.ndim == 2 and self.mrope_section and _is_cuda:
return self._forward_triton(positions, query, key)
else:
return self._forward_native(positions, query, key)
def _forward_triton(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
assert positions.ndim == 1 or positions.ndim == 2 assert positions.ndim == 1 or positions.ndim == 2
assert key is not None assert key is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment