Unverified commit 74de76c6, authored by Yuan Luo and committed by GitHub
Browse files

Revise MRotaryEmbedding's forward (#11859)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: 羽癫 <yudian.zy@antgroup.com>
Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
parent 9c0b1eb5
......@@ -1280,7 +1280,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def forward_native(
def _forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
......@@ -1340,6 +1340,27 @@ class MRotaryEmbedding(RotaryEmbedding):
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with optional Triton kernel acceleration.
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2
if positions.ndim == 2 and self.mrope_section and _is_cuda:
return self._forward_triton(positions, query, key)
else:
return self._forward_native(positions, query, key)
def _forward_triton(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment