Commit 17eb306f (unverified), authored by Zhengyuan Su (苏政渊), committed by GitHub
Browse files

[Bugfix] Add contiguous call inside rope kernel wrapper (#17091)


Signed-off-by: 苏政渊 <suzhengyuan@moonshot.cn>
Co-authored-by: 苏政渊 <suzhengyuan@moonshot.cn>
parent 165cb563
@@ -158,8 +158,13 @@ def rotary_embedding(
 cos_sin_cache: torch.Tensor,
 is_neox: bool,
 ) -> None:
-torch.ops._C.rotary_embedding(positions, query, key, head_size,
-cos_sin_cache, is_neox)
+# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+query_contiguous = query.contiguous()
+key_contiguous = key.contiguous()
+torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
+head_size, cos_sin_cache, is_neox)
+query.copy_(query_contiguous)
+key.copy_(key_contiguous)
 def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
@@ -167,9 +172,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
 cos_sin_cache: torch.Tensor, is_neox: bool,
 rot_dim: int,
 cos_sin_cache_offsets: torch.Tensor) -> None:
-torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
+# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
+query_contiguous = query.contiguous()
+key_contiguous = key.contiguous()
+torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
+key_contiguous, head_size,
 cos_sin_cache, is_neox, rot_dim,
 cos_sin_cache_offsets)
+query.copy_(query_contiguous)
+key.copy_(key_contiguous)
 # layer norm ops
......
@@ -938,8 +938,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 decode_ql_nope, decode_q_pe = \
 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
 decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
-decode_k_pe)
+attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
 if has_prefill:
 assert attn_metadata.prefill is not None
@@ -948,8 +947,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
 prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
 prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-attn_metadata.prefill.input_positions,
-prefill_q_pe.contiguous(), prefill_k_pe)
+attn_metadata.prefill.input_positions, prefill_q_pe,
+prefill_k_pe)
 # write the latent and rope to kv cache
 if kv_cache.numel() > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment