Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2bfbdca2
Unverified
Commit
2bfbdca2
authored
Mar 26, 2026
by
Jee Jee Li
Committed by
GitHub
Mar 25, 2026
Browse files
[Bugfix] Fix benchmark_fused_collective.py (#38082)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
29080945
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
7 deletions
+19
-7
benchmarks/kernels/benchmark_fused_collective.py
benchmarks/kernels/benchmark_fused_collective.py
+19
-7
No files found.
benchmarks/kernels/benchmark_fused_collective.py
View file @
2bfbdca2
...
...
@@ -25,6 +25,7 @@ import pandas as pd
import
torch
# type: ignore
import
torch.distributed
as
dist
# type: ignore
from
vllm._custom_ops
import
create_fp4_output_tensors
from
vllm.config.vllm
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.distributed
import
(
tensor_model_parallel_all_reduce
,
...
...
@@ -46,7 +47,7 @@ RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP
=
(
torch
.
ops
.
_C
.
fused_add_rms_norm_static_fp8_quant
)
SCALED_FP4_QUANT_OP
=
torch
.
ops
.
_C
.
scaled_fp4_quant
SCALED_FP4_QUANT_
OUT_
OP
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
out
logger
=
init_logger
(
__name__
)
...
...
@@ -334,13 +335,23 @@ class VllmFusedAllreduce:
output_scale
:
torch
.
Tensor
,
):
allreduce_out
=
tensor_model_parallel_all_reduce
(
input_tensor
)
rms_out
=
self
.
rms_norm
(
allreduce_out
,
residual
)
rms_output
=
self
.
rms_norm
(
allreduce_out
,
residual
)
if
residual
is
None
:
rms_out
=
rms_output
else
:
rms_out
,
residual_out
=
rms_output
SCALED_FP4_QUANT_OUT_OP
(
rms_out
,
input_global_scale
,
True
,
output
=
quant_out
,
output_scale
=
output_scale
,
)
if
residual
is
None
:
SCALED_FP4_QUANT_OP
(
quant_out
,
rms_out
,
output_scale
,
input_global_scale
)
return
quant_out
,
output_scale
else
:
rms_out
,
residual_out
=
rms_out
SCALED_FP4_QUANT_OP
(
quant_out
,
rms_out
,
output_scale
,
input_global_scale
)
return
quant_out
,
residual_out
,
output_scale
...
...
@@ -362,8 +373,9 @@ def create_test_tensors(
scale_fp4
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
quant_out_fp8
=
torch
.
empty_like
(
input_tensor
,
dtype
=
FP8_DTYPE
)
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
fp4_quant_out
=
torch
.
empty
((
num_tokens
,
hidden_dim
//
2
),
dtype
=
torch
.
uint8
)
fp4_output_scale
=
torch
.
empty
((
128
,
4
),
dtype
=
torch
.
int32
)
fp4_quant_out
,
fp4_output_scale
=
create_fp4_output_tensors
(
num_tokens
,
hidden_dim
,
input_tensor
.
device
,
True
)
return
(
input_tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment