Unverified Commit bb517fe3 authored by Yuan Luo, committed by GitHub
Browse files

[HotFix] Disable torch dynamo for mrope_triton kernel (#12593)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 0e82fd3d
......@@ -1302,6 +1302,31 @@ def triton_mrope(
return q, k
@torch._dynamo.disable()
def triton_mrope_wrapper(
    query,
    key,
    cos,
    sin,
    mrope_section,
    head_size,
    rotary_dim,
    mrope_interleaved,
    is_neox_style,
):
    """Invoke ``triton_mrope`` from a function marked ``torch._dynamo.disable()``.

    This wrapper adds no logic of its own: every argument is forwarded
    positionally, in the same order, to ``triton_mrope``. Its only purpose is
    to carry the ``torch._dynamo.disable()`` decorator so the call into the
    Triton kernel path is made from a Dynamo-disabled function.

    Args:
        query: Forwarded unchanged as the 1st positional argument.
        key: Forwarded unchanged as the 2nd positional argument.
        cos: Forwarded unchanged as the 3rd positional argument.
        sin: Forwarded unchanged as the 4th positional argument.
        mrope_section: Forwarded unchanged as the 5th positional argument.
        head_size: Forwarded unchanged as the 6th positional argument.
        rotary_dim: Forwarded unchanged as the 7th positional argument.
        mrope_interleaved: Forwarded unchanged as the 8th positional argument.
        is_neox_style: Forwarded unchanged as the 9th positional argument.

    Returns:
        Whatever ``triton_mrope`` returns; callers unpack it as ``q, k``.
    """
    # Keep the call positional: the parameter names of ``triton_mrope`` are
    # not guaranteed to match the names used by this wrapper.
    rotated = triton_mrope(
        query,
        key,
        cos,
        sin,
        mrope_section,
        head_size,
        rotary_dim,
        mrope_interleaved,
        is_neox_style,
    )
    return rotated
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
......@@ -1460,8 +1485,7 @@ class MRotaryEmbedding(RotaryEmbedding):
if positions.ndim == 2:
assert self.mrope_section
torch._dynamo.graph_break()
q, k = triton_mrope(
q, k = triton_mrope_wrapper(
query,
key,
cos,
......@@ -1472,7 +1496,6 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_interleaved,
self.is_neox_style,
)
torch._dynamo.graph_break()
return q.reshape(query_shape), k.reshape(key_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment