"vllm/vscode:/vscode.git/clone" did not exist on "e2b31243c092e9f4ade5ffe4bf9a5d5ddae06ca7"
Unverified Commit c719c405 authored by elvischenv's avatar elvischenv Committed by GitHub
Browse files

[Bugfix] Defunctionalize TRTLLM AR+Norm op for avoiding extra clone kernel before it (#29631)


Signed-off-by: default avatarelvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent b08025a8
...@@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -103,6 +103,18 @@ class FixFunctionalizationPass(VllmInductorPass):
]: ]:
mutated_args = {1: "result"} mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args) self.defunctionalize(graph, node, mutated_args)
elif (
at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {
1: "allreduce_in",
2: "residual",
3: "norm_out",
4: "quant_out",
5: "scale_out",
}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both # For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs # silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer. # pathway gets the wrong answer.
......
...@@ -75,7 +75,7 @@ def find_op_nodes( ...@@ -75,7 +75,7 @@ def find_op_nodes(
return return
assert isinstance(op, OpOverload) assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op) yield from graph.find_nodes(op="call_function", target=op)
for n in graph.find_nodes(op="call_function", target=auto_functionalized): for n in graph.find_nodes(op="call_function", target=auto_functionalized):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment