Unverified Commit f389f017 authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

Optimize triton_mrope with torch compile (#12112)


Co-authored-by: default avatarluoyuan.luo <luoyuan.luo@antgroup.com>
parent caa4819b
...@@ -1424,6 +1424,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1424,6 +1424,7 @@ class MRotaryEmbedding(RotaryEmbedding):
else: else:
return self._forward_native(positions, query, key) return self._forward_native(positions, query, key)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _forward_triton( def _forward_triton(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -1442,6 +1443,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1442,6 +1443,7 @@ class MRotaryEmbedding(RotaryEmbedding):
if positions.ndim == 2: if positions.ndim == 2:
assert self.mrope_section assert self.mrope_section
torch._dynamo.graph_break()
q, k = triton_mrope( q, k = triton_mrope(
query, query,
key, key,
...@@ -1453,6 +1455,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1453,6 +1455,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_interleaved, self.mrope_interleaved,
self.is_neox_style, self.is_neox_style,
) )
torch._dynamo.graph_break()
return q.reshape(query_shape), k.reshape(key_shape) return q.reshape(query_shape), k.reshape(key_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment