Commit 17eb306f (unverified), authored by Zhengyuan Su (苏政渊), committed by GitHub
Browse files

[Bugfix] Add contiguous call inside rope kernel wrapper (#17091)


Signed-off-by: 苏政渊 <suzhengyuan@moonshot.cn>
Co-authored-by: 苏政渊 <suzhengyuan@moonshot.cn>
parent 165cb563
......@@ -158,8 +158,13 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
torch.ops._C.rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox)
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous()
torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
head_size, cos_sin_cache, is_neox)
query.copy_(query_contiguous)
key.copy_(key_contiguous)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
......@@ -167,9 +172,15 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous()
torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
key_contiguous, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
query.copy_(query_contiguous)
key.copy_(key_contiguous)
# layer norm ops
......
......@@ -938,8 +938,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
decode_k_pe)
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
assert attn_metadata.prefill is not None
......@@ -948,8 +947,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions,
prefill_q_pe.contiguous(), prefill_k_pe)
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment