Unverified Commit bb517fe3 authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

[HotFix] Disable torch dynamo for mrope_triton kernel (#12593)


Co-authored-by: default avatarluoyuan.luo <luoyuan.luo@antgroup.com>
parent 0e82fd3d
...@@ -1302,6 +1302,31 @@ def triton_mrope( ...@@ -1302,6 +1302,31 @@ def triton_mrope(
return q, k return q, k
@torch._dynamo.disable()
def triton_mrope_wrapper(
query,
key,
cos,
sin,
mrope_section,
head_size,
rotary_dim,
mrope_interleaved,
is_neox_style,
):
return triton_mrope(
query,
key,
cos,
sin,
mrope_section,
head_size,
rotary_dim,
mrope_interleaved,
is_neox_style,
)
class MRotaryEmbedding(RotaryEmbedding): class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections.""" """Rotary Embedding with Multimodal Sections."""
...@@ -1460,8 +1485,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1460,8 +1485,7 @@ class MRotaryEmbedding(RotaryEmbedding):
if positions.ndim == 2: if positions.ndim == 2:
assert self.mrope_section assert self.mrope_section
torch._dynamo.graph_break() q, k = triton_mrope_wrapper(
q, k = triton_mrope(
query, query,
key, key,
cos, cos,
...@@ -1472,7 +1496,6 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1472,7 +1496,6 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_interleaved, self.mrope_interleaved,
self.is_neox_style, self.is_neox_style,
) )
torch._dynamo.graph_break()
return q.reshape(query_shape), k.reshape(key_shape) return q.reshape(query_shape), k.reshape(key_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment