Unverified commit 648951a9, authored by Mayank Ketkar, committed by GitHub
Browse files

[Bugfix] Fix benchmark_fused_collective crash on CustomOp init (#34665)


Signed-off-by: Mayank Ketkar <mketkar@zoox.com>
Signed-off-by: Mayank Ketkar <mayket04@gmail.com>
Co-authored-by: Mayank Ketkar <mketkar@zoox.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent f72061a1
...@@ -408,18 +408,18 @@ def run_benchmarks( ...@@ -408,18 +408,18 @@ def run_benchmarks(
rms_eps = 1e-6 rms_eps = 1e-6
results = {} results = {}
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
use_oneshot_options = [False] if no_oneshot else [True, False] use_oneshot_options = [False] if no_oneshot else [True, False]
# Create RMSNorm and QuantFP8 layers once for native benchmarks
if "none" in quant_modes: if "none" in quant_modes:
# Standard AllReduce + RMSNorm # Standard AllReduce + RMSNorm
# Re-create VllmFusedAllreduce per config so CustomOp binds the
# correct forward method (native vs custom kernel).
for custom_op in ["-rms_norm", "+rms_norm"]: for custom_op in ["-rms_norm", "+rms_norm"]:
with set_current_vllm_config( with set_current_vllm_config(
VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op])) VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
): ):
try: try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
suffix = ( suffix = (
"_custom_rms_norm" if "+" in custom_op else "_native_rms_norm" "_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
) )
...@@ -438,6 +438,7 @@ def run_benchmarks( ...@@ -438,6 +438,7 @@ def run_benchmarks(
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
): ):
try: try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
standard_allreduce_rmsnorm_native_compiled = torch.compile( standard_allreduce_rmsnorm_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm, vllm_fused_allreduce.allreduce_rmsnorm,
fullgraph=True, fullgraph=True,
...@@ -482,7 +483,7 @@ def run_benchmarks( ...@@ -482,7 +483,7 @@ def run_benchmarks(
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm" "_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
) )
for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]: for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
suffix += ( op_suffix = suffix + (
"_custom_quant_fp8" "_custom_quant_fp8"
if "+" in quant_fp8_custom_op if "+" in quant_fp8_custom_op
else "_native_quant_fp8" else "_native_quant_fp8"
...@@ -495,16 +496,17 @@ def run_benchmarks( ...@@ -495,16 +496,17 @@ def run_benchmarks(
) )
): ):
try: try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
time_ms = benchmark_operation( time_ms = benchmark_operation(
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
input_tensor, input_tensor,
residual=residual, residual=residual,
scale_factor=scale_fp8, scale_factor=scale_fp8,
) )
results[f"standard_allreduce{suffix}"] = time_ms results[f"standard_allreduce{op_suffix}"] = time_ms
except Exception as e: except Exception as e:
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e) logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
results[f"standard_allreduce{suffix}"] = float("inf") results[f"standard_allreduce{op_suffix}"] = float("inf")
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled # Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
with set_current_vllm_config( with set_current_vllm_config(
...@@ -515,6 +517,7 @@ def run_benchmarks( ...@@ -515,6 +517,7 @@ def run_benchmarks(
) )
): ):
try: try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile( standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant, vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
fullgraph=True, fullgraph=True,
...@@ -580,6 +583,7 @@ def run_benchmarks( ...@@ -580,6 +583,7 @@ def run_benchmarks(
) )
): ):
try: try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
time_ms = benchmark_operation( time_ms = benchmark_operation(
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
input_tensor, input_tensor,
...@@ -598,6 +602,7 @@ def run_benchmarks( ...@@ -598,6 +602,7 @@ def run_benchmarks(
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"])) VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
): ):
try: try:
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile( standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant, vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
fullgraph=True, fullgraph=True,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment