Unverified commit f389f017, authored by Yuan Luo, committed by GitHub
Browse files

Optimize triton_mrope with torch compile (#12112)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent caa4819b
......@@ -1424,6 +1424,7 @@ class MRotaryEmbedding(RotaryEmbedding):
else:
return self._forward_native(positions, query, key)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _forward_triton(
self,
positions: torch.Tensor,
......@@ -1442,6 +1443,7 @@ class MRotaryEmbedding(RotaryEmbedding):
if positions.ndim == 2:
assert self.mrope_section
torch._dynamo.graph_break()
q, k = triton_mrope(
query,
key,
......@@ -1453,6 +1455,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_interleaved,
self.is_neox_style,
)
torch._dynamo.graph_break()
return q.reshape(query_shape), k.reshape(key_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment