Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e6a1e16
Unverified
Commit
3e6a1e16
authored
Mar 16, 2026
by
Terry Gao
Committed by
GitHub
Mar 16, 2026
Browse files
[Custom Ops] Add functional + out variant for scaled_fp4_quant (#34389)
Signed-off-by:
tianrengao
<
terrygao87@gmail.com
>
parent
7961486a
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
213 additions
and
44 deletions
+213
-44
csrc/ops.h
csrc/ops.h
+8
-4
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+34
-3
csrc/quantization/fp4/nvfp4_utils.cuh
csrc/quantization/fp4/nvfp4_utils.cuh
+13
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+15
-4
tests/compile/passes/distributed/test_fusion_all_reduce.py
tests/compile/passes/distributed/test_fusion_all_reduce.py
+1
-1
tests/kernels/quantization/test_nvfp4_quant.py
tests/kernels/quantization/test_nvfp4_quant.py
+46
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+85
-21
vllm/compilation/passes/fusion/act_quant_fusion.py
vllm/compilation/passes/fusion/act_quant_fusion.py
+2
-2
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+5
-5
vllm/compilation/passes/fusion/attn_quant_fusion.py
vllm/compilation/passes/fusion/attn_quant_fusion.py
+2
-2
vllm/compilation/passes/fusion/matcher_utils.py
vllm/compilation/passes/fusion/matcher_utils.py
+1
-1
vllm/compilation/passes/fusion/rms_quant_fusion.py
vllm/compilation/passes/fusion/rms_quant_fusion.py
+1
-1
No files found.
csrc/ops.h
View file @
3e6a1e16
...
@@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
...
@@ -295,10 +295,14 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
std
::
vector
<
torch
::
Tensor
>
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
);
std
::
vector
<
torch
::
Tensor
>
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
scaled_fp4_quant_func
(
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_scale
,
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
);
bool
is_sf_swizzled_layout
);
void
scaled_fp4_quant_out
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
,
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
);
void
scaled_fp4_experts_quant
(
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
...
...
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
3e6a1e16
...
@@ -16,6 +16,8 @@
...
@@ -16,6 +16,8 @@
#include <torch/all.h>
#include <torch/all.h>
#include "nvfp4_utils.cuh"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
...
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
...
@@ -51,9 +53,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
torch
::
Tensor
const
&
output_scale_offset_by_experts
);
#endif
#endif
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
void
scaled_fp4_quant_out
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
,
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
)
{
bool
is_sf_swizzled_layout
,
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_sf
)
{
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return
scaled_fp4_quant_sm1xxa
(
output
,
input
,
output_sf
,
input_sf
,
return
scaled_fp4_quant_sm1xxa
(
output
,
input
,
output_sf
,
input_sf
,
...
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
...
@@ -62,6 +65,34 @@ void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
scaled_fp4_quant_func
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
)
{
int64_t
n
=
input
.
size
(
-
1
);
int64_t
m
=
input
.
numel
()
/
n
;
auto
device
=
input
.
device
();
// Two fp4 values packed into a uint8
auto
output
=
torch
::
empty
(
{
m
,
n
/
2
},
torch
::
TensorOptions
().
device
(
device
).
dtype
(
torch
::
kUInt8
));
torch
::
Tensor
output_sf
;
if
(
is_sf_swizzled_layout
)
{
auto
[
sf_m
,
sf_n
]
=
vllm
::
computeSwizzledSFShape
(
m
,
n
);
output_sf
=
torch
::
empty
(
{
sf_m
,
sf_n
},
torch
::
TensorOptions
().
device
(
device
).
dtype
(
torch
::
kInt32
));
}
else
{
output_sf
=
torch
::
empty
(
{
m
,
n
/
CVT_FP4_SF_VEC_SIZE
},
torch
::
TensorOptions
().
device
(
device
).
dtype
(
torch
::
kUInt8
));
}
scaled_fp4_quant_out
(
input
,
input_sf
,
is_sf_swizzled_layout
,
output
,
output_sf
);
return
{
output
,
output_sf
};
}
void
scaled_fp4_experts_quant
(
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_global_scale
,
...
...
csrc/quantization/fp4/nvfp4_utils.cuh
View file @
3e6a1e16
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.h>
#include <utility>
#include "../../cuda_vec_utils.cuh"
#include "../../cuda_vec_utils.cuh"
...
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
...
@@ -54,6 +55,18 @@ inline int computeEffectiveRows(int m) {
return
round_up
(
m
,
ROW_TILE
);
return
round_up
(
m
,
ROW_TILE
);
}
}
// Compute the shape of the swizzled SF output tensor.
// Returns (rounded_m, rounded_n / 4) where:
// rounded_m = round_up(m, 128)
// rounded_n = round_up(n / CVT_FP4_SF_VEC_SIZE, 4)
inline
std
::
pair
<
int64_t
,
int64_t
>
computeSwizzledSFShape
(
int64_t
m
,
int64_t
n
)
{
int64_t
rounded_m
=
round_up
(
m
,
static_cast
<
int64_t
>
(
128
));
int64_t
scale_n
=
n
/
CVT_FP4_SF_VEC_SIZE
;
int64_t
rounded_n
=
round_up
(
scale_n
,
static_cast
<
int64_t
>
(
4
));
return
{
rounded_m
,
rounded_n
/
4
};
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec8_to_e2m1
(
float
(
&
array
)[
8
])
{
inline
__device__
uint32_t
fp32_vec8_to_e2m1
(
float
(
&
array
)[
8
])
{
uint32_t
val
;
uint32_t
val
;
...
...
csrc/torch_bindings.cpp
View file @
3e6a1e16
...
@@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -564,10 +564,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute NVFP4 block quantized tensor.
// Compute NVFP4 block quantized tensor.
ops
.
def
(
ops
.
def
(
"scaled_fp4_quant(Tensor! output, Tensor input,"
"scaled_fp4_quant(Tensor input,"
" Tensor! output_scale, Tensor input_scale, bool "
" Tensor input_scale, bool "
"is_sf_swizzled_layout) -> ()"
);
"is_sf_swizzled_layout) -> (Tensor, Tensor)"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant_func
);
// Out variant
// TODO: Add {at::Tag::out_variant} tag and update all call sites
// to use the functional variant once vLLM upgrades PyTorch.
// See pytorch/pytorch#176117.
ops
.
def
(
"scaled_fp4_quant.out(Tensor input,"
" Tensor input_scale, bool "
"is_sf_swizzled_layout, *, Tensor(a!) output, Tensor(b!) output_scale) "
"-> ()"
);
ops
.
impl
(
"scaled_fp4_quant.out"
,
torch
::
kCUDA
,
&
scaled_fp4_quant_out
);
// Compute NVFP4 experts quantization.
// Compute NVFP4 experts quantization.
ops
.
def
(
ops
.
def
(
...
...
tests/compile/passes/distributed/test_fusion_all_reduce.py
View file @
3e6a1e16
...
@@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
...
@@ -179,7 +179,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def
ops_in_model_before
(
self
):
def
ops_in_model_before
(
self
):
return
[
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
,
torch
.
ops
.
vllm
.
all_reduce
.
default
,
torch
.
ops
.
_C
.
scaled_fp4_quant
.
defaul
t
,
torch
.
ops
.
_C
.
scaled_fp4_quant
.
ou
t
,
]
]
...
...
tests/kernels/quantization/test_nvfp4_quant.py
View file @
3e6a1e16
...
@@ -159,6 +159,52 @@ def test_quantize_to_fp4(
...
@@ -159,6 +159,52 @@ def test_quantize_to_fp4(
torch
.
testing
.
assert_close
(
scale_ans
,
scale_ref
)
torch
.
testing
.
assert_close
(
scale_ans
,
scale_ref
)
@
pytest
.
mark
.
parametrize
(
"shape"
,
[(
32
,
4096
),
(
128
,
4096
),
(
1
,
64
),
(
127
,
1024
),
(
256
,
16384
)],
)
@
pytest
.
mark
.
parametrize
(
"is_sf_swizzled_layout"
,
[
True
,
False
])
@
torch
.
inference_mode
()
def
test_python_util_matches_cpp_allocation
(
shape
:
tuple
[
int
,
int
],
is_sf_swizzled_layout
:
bool
,
)
->
None
:
"""
Verify that the Python utility (create_fp4_output_tensors) allocates
tensors with the same shapes and dtypes as the C++ functional variant
(scaled_fp4_quant_func).
"""
from
vllm._custom_ops
import
create_fp4_output_tensors
torch
.
set_default_device
(
"cuda:0"
)
m
,
n
=
shape
input_tensor
=
torch
.
randn
((
m
,
n
),
dtype
=
torch
.
bfloat16
)
input_scale
=
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
"cuda:0"
)
# C++ functional variant allocates internally
cpp_out
,
cpp_scale
=
torch
.
ops
.
_C
.
scaled_fp4_quant
(
input_tensor
,
input_scale
,
is_sf_swizzled_layout
)
# Python utility
py_out
,
py_scale
=
create_fp4_output_tensors
(
m
,
n
,
torch
.
device
(
"cuda:0"
),
is_sf_swizzled_layout
)
assert
py_out
.
shape
==
cpp_out
.
shape
,
(
f
"Output shape mismatch: Python
{
py_out
.
shape
}
vs C++
{
cpp_out
.
shape
}
"
)
assert
py_out
.
dtype
==
cpp_out
.
dtype
,
(
f
"Output dtype mismatch: Python
{
py_out
.
dtype
}
vs C++
{
cpp_out
.
dtype
}
"
)
assert
py_scale
.
shape
==
cpp_scale
.
shape
,
(
f
"Scale shape mismatch: Python
{
py_scale
.
shape
}
vs C++
{
cpp_scale
.
shape
}
"
)
assert
py_scale
.
dtype
==
cpp_scale
.
dtype
,
(
f
"Scale dtype mismatch: Python
{
py_scale
.
dtype
}
vs C++
{
cpp_scale
.
dtype
}
"
)
@
pytest
.
mark
.
parametrize
(
"pad_shape"
,
PAD_SHAPES
)
@
pytest
.
mark
.
parametrize
(
"pad_shape"
,
PAD_SHAPES
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_quantize_to_fp4_padded
(
pad_shape
:
tuple
[
int
,
int
])
->
None
:
def
test_quantize_to_fp4_padded
(
pad_shape
:
tuple
[
int
,
int
])
->
None
:
...
...
vllm/_custom_ops.py
View file @
3e6a1e16
...
@@ -29,6 +29,81 @@ else:
...
@@ -29,6 +29,81 @@ else:
from
torch.library
import
impl_abstract
as
register_fake
from
torch.library
import
impl_abstract
as
register_fake
# scaled_fp4_quant functional + out variant for torch.compile buffer management
def
create_fp4_scale_tensor
(
m
:
int
,
n
:
int
,
device
:
torch
.
device
,
is_sf_swizzled_layout
:
bool
,
)
->
torch
.
Tensor
:
"""
Allocate the output scale tensor for scaled_fp4_quant.
When is_sf_swizzled_layout=True, we use rounded values to store the
swizzled scales. Due to the requirement of the Tensor Core, the minimum
tile is 128x4 for the scales. So, we first pad the scales to multiples
of 128 (rows) and 4 (cols). Then, the scales (in float8_e4m3fn) are
packed into an int32 for every 4 values. More:
https://docs.nvidia.com/cuda/parallel-thread-execution/
#tcgen05-mma-scale-factor-b-layout-4x
"""
from
vllm.utils.math_utils
import
round_up
block_size
=
16
if
is_sf_swizzled_layout
:
rounded_m
=
round_up
(
m
,
128
)
scale_n
=
n
//
block_size
rounded_n
=
round_up
(
scale_n
,
4
)
return
torch
.
empty
(
(
rounded_m
,
rounded_n
//
4
),
device
=
device
,
dtype
=
torch
.
int32
)
else
:
return
torch
.
empty
((
m
,
n
//
block_size
),
device
=
device
,
dtype
=
torch
.
uint8
)
def
create_fp4_output_tensors
(
m
:
int
,
n
:
int
,
device
:
torch
.
device
,
is_sf_swizzled_layout
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Allocate both output tensors for scaled_fp4_quant:
(quantized_output, output_scale).
Must match the C++ scaled_fp4_quant_func allocation exactly.
"""
output
=
torch
.
empty
((
m
,
n
//
2
),
device
=
device
,
dtype
=
torch
.
uint8
)
output_scale
=
create_fp4_scale_tensor
(
m
,
n
,
device
,
is_sf_swizzled_layout
)
return
output
,
output_scale
if
hasattr
(
torch
.
ops
,
"_C"
)
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
@
register_fake
(
"_C::scaled_fp4_quant"
)
def
_scaled_fp4_quant_fake
(
input
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
is_sf_swizzled_layout
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
n
=
input
.
shape
[
-
1
]
m
=
input
.
numel
()
//
n
return
create_fp4_output_tensors
(
m
,
n
,
input
.
device
,
is_sf_swizzled_layout
)
@
register_fake
(
"_C::scaled_fp4_quant.out"
)
def
_scaled_fp4_quant_out_fake
(
input
:
torch
.
Tensor
,
input_scale
:
torch
.
Tensor
,
is_sf_swizzled_layout
:
bool
,
*
,
output
:
torch
.
Tensor
,
output_scale
:
torch
.
Tensor
,
)
->
None
:
return
None
# page attention ops
# page attention ops
def
paged_attention_v1
(
def
paged_attention_v1
(
out
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
...
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
...
@@ -1644,7 +1719,6 @@ def scaled_fp4_quant(
input
=
input
.
reshape
(
other_dims
,
input
.
shape
[
-
1
])
input
=
input
.
reshape
(
other_dims
,
input
.
shape
[
-
1
])
m
,
n
=
input
.
shape
m
,
n
=
input
.
shape
block_size
=
16
block_size
=
16
device
=
input
.
device
assert
n
%
block_size
==
0
,
f
"last dim has to be multiple of 16, but got
{
n
}
."
assert
n
%
block_size
==
0
,
f
"last dim has to be multiple of 16, but got
{
n
}
."
assert
input
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
),
(
assert
input
.
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
),
(
...
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
...
@@ -1658,26 +1732,16 @@ def scaled_fp4_quant(
input
,
input_global_scale
input
,
input_global_scale
)
)
else
:
else
:
# Two fp4 values will be packed into an uint8.
# Pre-allocate and call .out variant (same behavior as old in-place API)
output
=
torch
.
empty
((
m
,
n
//
2
),
device
=
device
,
dtype
=
torch
.
uint8
)
output
,
output_scale
=
create_fp4_output_tensors
(
if
is_sf_swizzled_layout
:
m
,
n
,
input
.
device
,
is_sf_swizzled_layout
# We use the rounded values to store the swizzled values. Due to the
)
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
torch
.
ops
.
_C
.
scaled_fp4_quant
.
out
(
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
input
,
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
input_global_scale
,
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
is_sf_swizzled_layout
,
round_up
=
lambda
x
,
y
:
(
x
+
y
-
1
)
//
y
*
y
output
=
output
,
rounded_m
=
round_up
(
m
,
128
)
output_scale
=
output_scale
,
scale_n
=
n
//
block_size
rounded_n
=
round_up
(
scale_n
,
4
)
output_scale
=
torch
.
empty
(
(
rounded_m
,
rounded_n
//
4
),
device
=
device
,
dtype
=
torch
.
int32
)
else
:
output_scale
=
torch
.
empty
((
m
,
n
//
16
),
device
=
device
,
dtype
=
torch
.
uint8
)
torch
.
ops
.
_C
.
scaled_fp4_quant
(
output
,
input
,
output_scale
,
input_global_scale
,
is_sf_swizzled_layout
)
)
output_scale
=
output_scale
.
view
(
torch
.
float8_e4m3fn
)
output_scale
=
output_scale
.
view
(
torch
.
float8_e4m3fn
)
...
...
vllm/compilation/passes/fusion/act_quant_fusion.py
View file @
3e6a1e16
...
@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -148,11 +148,11 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
result_silu_mul
=
self
.
silu_and_mul_matcher
(
input
)
at
=
auto_functionalized
(
at
=
auto_functionalized
(
self
.
QUANT_OP
,
self
.
QUANT_OP
,
output
=
result
,
input
=
result_silu_mul
,
input
=
result_silu_mul
,
output_scale
=
output_scale
,
input_scale
=
scale
,
input_scale
=
scale
,
is_sf_swizzled_layout
=
True
,
is_sf_swizzled_layout
=
True
,
output
=
result
,
output_scale
=
output_scale
,
)
)
return
at
[
1
],
at
[
2
]
return
at
[
1
],
at
[
2
]
...
...
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
View file @
3e6a1e16
...
@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
...
@@ -47,7 +47,7 @@ if find_spec("flashinfer"):
pass
pass
if
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
if
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
STATIC_FP4_QUANT_OP
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
defaul
t
STATIC_FP4_QUANT_OP
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
ou
t
# Max size of the input tensor per world size per device capability
# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
# to use flashinfer fused allreduce
...
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -562,11 +562,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
rms
=
self
.
rmsnorm_matcher
(
all_reduce
,
weight
)
quant_out_tuple
=
auto_functionalized
(
quant_out_tuple
=
auto_functionalized
(
STATIC_FP4_QUANT_OP
,
STATIC_FP4_QUANT_OP
,
output
=
quant_result
,
input
=
rms
,
input
=
rms
,
output_scale
=
output_scale
,
input_scale
=
input_global_scale
,
input_scale
=
input_global_scale
,
is_sf_swizzled_layout
=
True
,
is_sf_swizzled_layout
=
True
,
output
=
quant_result
,
output_scale
=
output_scale
,
)
)
# quant_out, allreduce_output, output_scale
# quant_out, allreduce_output, output_scale
...
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -660,11 +660,11 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
rms
,
residual
=
self
.
rmsnorm_matcher
(
allreduce_output
,
weight
,
residual
)
quant_out_tuple
=
auto_functionalized
(
quant_out_tuple
=
auto_functionalized
(
STATIC_FP4_QUANT_OP
,
STATIC_FP4_QUANT_OP
,
output
=
quant_result
,
input
=
rms
,
input
=
rms
,
output_scale
=
output_scale
,
input_scale
=
input_global_scale
,
input_scale
=
input_global_scale
,
is_sf_swizzled_layout
=
True
,
is_sf_swizzled_layout
=
True
,
output
=
quant_result
,
output_scale
=
output_scale
,
)
)
# quant_out, allreduce_output, output_scale
# quant_out, allreduce_output, output_scale
...
...
vllm/compilation/passes/fusion/attn_quant_fusion.py
View file @
3e6a1e16
...
@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -250,11 +250,11 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
)
)
at2
=
auto_functionalized
(
at2
=
auto_functionalized
(
self
.
QUANT_OP
,
self
.
QUANT_OP
,
output
=
output_quant
,
input
=
attn_out_view
,
input
=
attn_out_view
,
output_scale
=
output_scale
,
input_scale
=
input_scale
,
input_scale
=
input_scale
,
is_sf_swizzled_layout
=
True
,
is_sf_swizzled_layout
=
True
,
output
=
output_quant
,
output_scale
=
output_scale
,
)
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
at2
[
1
],
output_scale_view
return
at2
[
1
],
output_scale_view
...
...
vllm/compilation/passes/fusion/matcher_utils.py
View file @
3e6a1e16
...
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
...
@@ -38,7 +38,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
}
}
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
QUANT_OPS
[
kNvfp4Dynamic
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
defaul
t
# noqa: E501
QUANT_OPS
[
kNvfp4Dynamic
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
ou
t
# noqa: E501
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
QUANT_OPS
[
kFp8Dynamic128Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
QUANT_OPS
[
kFp8Dynamic128Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
...
...
vllm/compilation/passes/fusion/rms_quant_fusion.py
View file @
3e6a1e16
...
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
...
@@ -63,7 +63,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8DynamicTokenSym
:
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
.
default
,
# noqa: E501
kFp8DynamicTokenSym
:
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
.
default
,
# noqa: E501
}
}
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
if
current_platform
.
is_cuda
()
and
hasattr
(
torch
.
ops
.
_C
,
"scaled_fp4_quant"
):
QUANT_OPS
[
kNvfp4Dynamic
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
defaul
t
QUANT_OPS
[
kNvfp4Dynamic
]
=
torch
.
ops
.
_C
.
scaled_fp4_quant
.
ou
t
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
QUANT_OPS
[
kFp8Dynamic128Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
QUANT_OPS
[
kFp8Dynamic128Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
QUANT_OPS
[
kFp8Dynamic64Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
QUANT_OPS
[
kFp8Dynamic64Sym
]
=
torch
.
ops
.
_C
.
per_token_group_fp8_quant
.
default
# noqa: E501
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment