Unverified Commit e3691988 authored by Rohan Potdar's avatar Rohan Potdar Committed by GitHub
Browse files

[ROCm]: fix aiter rope functionalization (#35533)


Signed-off-by: default avatarRohan138 <rohanpotdar138@gmail.com>
parent 9fa6c68f
......@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
rope_targets = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
rope_targets.append(
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
......@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
if at_target in rope_targets:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment