Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcb9df99
Unverified
Commit
fcb9df99
authored
Jan 25, 2026
by
Roberto L. Castro
Committed by
GitHub
Jan 24, 2026
Browse files
[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)
Signed-off-by:
LopezCastroRoberto
<
rocastro@redhat.com
>
parent
1ebdff41
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
509 additions
and
152 deletions
+509
-152
benchmarks/kernels/bench_nvfp4_quant.py
benchmarks/kernels/bench_nvfp4_quant.py
+55
-22
csrc/ops.h
csrc/ops.h
+2
-1
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
...quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
+59
-20
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+4
-4
csrc/quantization/fp4/nvfp4_quant_entry.cu
csrc/quantization/fp4/nvfp4_quant_entry.cu
+6
-3
csrc/quantization/fp4/nvfp4_quant_kernels.cu
csrc/quantization/fp4/nvfp4_quant_kernels.cu
+141
-47
csrc/quantization/fp4/nvfp4_utils.cuh
csrc/quantization/fp4/nvfp4_utils.cuh
+171
-25
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-1
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
...s/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
+6
-2
tests/kernels/quantization/test_nvfp4_quant.py
tests/kernels/quantization/test_nvfp4_quant.py
+28
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+18
-13
vllm/compilation/activation_quant_fusion.py
vllm/compilation/activation_quant_fusion.py
+1
-0
vllm/compilation/collective_fusion.py
vllm/compilation/collective_fusion.py
+2
-0
vllm/compilation/fusion_attn.py
vllm/compilation/fusion_attn.py
+1
-0
vllm/model_executor/layers/fused_moe/utils.py
vllm/model_executor/layers/fused_moe/utils.py
+1
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
...mpressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
+4
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+3
-1
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
..._executor/layers/quantization/utils/flashinfer_fp4_moe.py
+5
-8
No files found.
benchmarks/kernels/bench_nvfp4_quant.py
View file @
fcb9df99
...
@@ -20,8 +20,12 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
...
@@ -20,8 +20,12 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
PROVIDER_CFGS
=
{
PROVIDER_CFGS
=
{
"vllm"
:
dict
(
backend
=
"vllm"
,
enabled
=
True
),
"vllm"
:
dict
(
backend
=
"vllm"
,
is_sf_swizzled_layout
=
False
,
enabled
=
True
),
"flashinfer"
:
dict
(
backend
=
"flashinfer"
,
enabled
=
True
),
"vllm-swizzle"
:
dict
(
backend
=
"vllm"
,
is_sf_swizzled_layout
=
True
,
enabled
=
True
),
"flashinfer"
:
dict
(
backend
=
"flashinfer"
,
is_sf_swizzled_layout
=
False
,
enabled
=
True
),
"flashinfer-swizzle"
:
dict
(
backend
=
"flashinfer"
,
is_sf_swizzled_layout
=
True
,
enabled
=
True
),
}
}
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
_enabled
=
[
k
for
k
,
v
in
PROVIDER_CFGS
.
items
()
if
v
[
"enabled"
]]
...
@@ -36,7 +40,7 @@ def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
...
@@ -36,7 +40,7 @@ def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
@
triton
.
testing
.
perf_report
(
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
],
x_vals
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
],
x_log
=
False
,
x_log
=
False
,
line_arg
=
"provider"
,
line_arg
=
"provider"
,
line_vals
=
_enabled
,
line_vals
=
_enabled
,
...
@@ -63,19 +67,36 @@ def benchmark(batch_size, provider, N, K):
...
@@ -63,19 +67,36 @@ def benchmark(batch_size, provider, N, K):
if
cfg
[
"backend"
]
==
"vllm"
:
if
cfg
[
"backend"
]
==
"vllm"
:
# vLLM's FP4 quantization
# vLLM's FP4 quantization
if
cfg
[
"is_sf_swizzled_layout"
]:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
),
lambda
:
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
,
is_sf_swizzled_layout
=
True
),
quantiles
=
quantiles
,
)
else
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
,
is_sf_swizzled_layout
=
False
),
quantiles
=
quantiles
,
quantiles
=
quantiles
,
)
)
elif
cfg
[
"backend"
]
==
"flashinfer"
:
elif
cfg
[
"backend"
]
==
"flashinfer"
:
# FlashInfer's FP4 quantization
# FlashInfer's FP4 quantization
# Use
is_sf_swizzled_layout
=True to match vLLM's output format
if
cfg
[
"
is_sf_swizzled_layout
"
]:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
flashinfer_fp4_quantize
(
lambda
:
flashinfer_fp4_quantize
(
a
,
a_global_scale
,
is_sf_swizzled_layout
=
True
a
,
a_global_scale
,
is_sf_swizzled_layout
=
True
),
),
quantiles
=
quantiles
,
quantiles
=
quantiles
,
)
)
else
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench_cudagraph
(
lambda
:
flashinfer_fp4_quantize
(
a
,
a_global_scale
,
is_sf_swizzled_layout
=
False
),
quantiles
=
quantiles
,
)
# Convert ms to us for better readability at small batch sizes
# Convert ms to us for better readability at small batch sizes
to_us
=
lambda
t_ms
:
t_ms
*
1000
to_us
=
lambda
t_ms
:
t_ms
*
1000
...
@@ -92,7 +113,9 @@ def prepare_shapes(args):
...
@@ -92,7 +113,9 @@ def prepare_shapes(args):
return
out
return
out
def
_test_accuracy_once
(
M
:
int
,
K
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
):
def
_test_accuracy_once
(
M
:
int
,
K
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
is_sf_swizzled_layout
:
bool
):
"""Test accuracy between vLLM and FlashInfer FP4 quantization."""
"""Test accuracy between vLLM and FlashInfer FP4 quantization."""
# Create input tensor
# Create input tensor
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
a
=
torch
.
randn
((
M
,
K
),
device
=
device
,
dtype
=
dtype
)
...
@@ -101,11 +124,13 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
...
@@ -101,11 +124,13 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
a_global_scale
=
compute_global_scale
(
a
)
a_global_scale
=
compute_global_scale
(
a
)
# vLLM quantization
# vLLM quantization
vllm_fp4
,
vllm_scale
=
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
)
vllm_fp4
,
vllm_scale
=
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
,
is_sf_swizzled_layout
=
is_sf_swizzled_layout
)
# FlashInfer quantization (with swizzled layout to match vLLM's output)
# FlashInfer quantization (with swizzled layout to match vLLM's output)
flashinfer_fp4
,
flashinfer_scale
=
flashinfer_fp4_quantize
(
flashinfer_fp4
,
flashinfer_scale
=
flashinfer_fp4_quantize
(
a
,
a_global_scale
,
is_sf_swizzled_layout
=
True
a
,
a_global_scale
,
is_sf_swizzled_layout
=
is_sf_swizzled_layout
)
)
flashinfer_scale
=
flashinfer_scale
.
view
(
torch
.
float8_e4m3fn
)
flashinfer_scale
=
flashinfer_scale
.
view
(
torch
.
float8_e4m3fn
)
...
@@ -114,7 +139,14 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
...
@@ -114,7 +139,14 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
vllm_fp4
,
vllm_fp4
,
flashinfer_fp4
,
flashinfer_fp4
,
)
)
print
(
f
"M=
{
M
}
, K=
{
K
}
, dtype=
{
dtype
}
: PASSED"
)
# Compare scales
torch
.
testing
.
assert_close
(
vllm_scale
,
flashinfer_scale
,
)
print
(
f
"M=
{
M
}
, K=
{
K
}
, dtype=
{
dtype
}
, is_sf_swizzled_layout=
{
is_sf_swizzled_layout
}
: PASSED"
# noqa: E501
)
def
test_accuracy
():
def
test_accuracy
():
...
@@ -130,9 +162,10 @@ def test_accuracy():
...
@@ -130,9 +162,10 @@ def test_accuracy():
Ms
=
[
1
,
1024
]
Ms
=
[
1
,
1024
]
Ks
=
[
4096
]
Ks
=
[
4096
]
for
is_sf_swizzled_layout
in
[
True
,
False
]:
for
M
in
Ms
:
for
M
in
Ms
:
for
K
in
Ks
:
for
K
in
Ks
:
_test_accuracy_once
(
M
,
K
,
dtype
,
device
)
_test_accuracy_once
(
M
,
K
,
dtype
,
device
,
is_sf_swizzled_layout
)
print
(
"
\n
All accuracy tests passed!"
)
print
(
"
\n
All accuracy tests passed!"
)
...
@@ -145,7 +178,7 @@ if __name__ == "__main__":
...
@@ -145,7 +178,7 @@ if __name__ == "__main__":
"--models"
,
"--models"
,
nargs
=
"+"
,
nargs
=
"+"
,
type
=
str
,
type
=
str
,
default
=
[
"meta-llama/Llama-3.
1-8
B-Instruct"
],
default
=
[
"meta-llama/Llama-3.
3-70
B-Instruct"
],
choices
=
list
(
WEIGHT_SHAPES
.
keys
()),
choices
=
list
(
WEIGHT_SHAPES
.
keys
()),
)
)
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
])
parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
[
1
])
...
...
csrc/ops.h
View file @
fcb9df99
...
@@ -293,7 +293,8 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
...
@@ -293,7 +293,8 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
);
void
scaled_fp4_experts_quant
(
void
scaled_fp4_experts_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
&
output
,
torch
::
Tensor
&
output_scale
,
...
...
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
View file @
fcb9df99
...
@@ -27,17 +27,24 @@
...
@@ -27,17 +27,24 @@
#include "cuda_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
#include "launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
#include "nvfp4_utils.cuh"
#include "nvfp4_utils.cuh"
namespace
vllm
{
namespace
vllm
{
// Use UE4M3 by default.
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
__launch_bounds__
(
1024
,
VLLM_BLOCKS_PER_SM
(
1024
))
__global__
void
__launch_bounds__
(
512
,
VLLM_BLOCKS_PER_SM
(
512
))
silu_mul_cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
silu_mul_cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
float
const
*
SFScale
,
uint32_t
*
out
,
int32_t
num_padded_cols
,
uint32_t
*
SFout
)
{
Type
const
*
__restrict__
in
,
using
PackedVec
=
PackedVec
<
Type
>
;
float
const
*
__restrict__
SFScale
,
uint32_t
*
__restrict__
out
,
uint32_t
*
__restrict__
SFout
)
{
using
PackedVec
=
vllm
::
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
...
@@ -49,34 +56,60 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
...
@@ -49,34 +56,60 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
// Get the global scaling factor, which will be applied to the SF.
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
// (448.f / (Alpha_A / 6.f)).
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
0
];
float
const
SFScaleVal
=
(
SFScale
==
nullptr
)
?
1.0
f
:
SFScale
[
0
];
int32_t
const
colIdx
=
blockDim
.
x
*
blockIdx
.
y
+
threadIdx
.
x
;
int
elem_idx
=
colIdx
*
CVT_FP4_ELTS_PER_THREAD
;
// Input tensor row/col loops.
// Input tensor row/col loops.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
colIdx
=
threadIdx
.
x
;
colIdx
<
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
if
(
colIdx
<
num_padded_cols
)
{
colIdx
+=
blockDim
.
x
)
{
PackedVec
in_vec
;
PackedVec
in_vec2
;
int64_t
inOffset
=
int64_t
inOffset
=
rowIdx
*
(
numCols
*
2
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
rowIdx
*
(
numCols
*
2
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
int64_t
inOffset2
=
rowIdx
*
(
numCols
*
2
/
CVT_FP4_ELTS_PER_THREAD
)
+
int64_t
inOffset2
=
rowIdx
*
(
numCols
*
2
/
CVT_FP4_ELTS_PER_THREAD
)
+
numCols
/
CVT_FP4_ELTS_PER_THREAD
+
colIdx
;
numCols
/
CVT_FP4_ELTS_PER_THREAD
+
colIdx
;
PackedVec
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
PackedVec
in_vec2
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset2
];
// Get the output tensor offset.
bool
valid
=
(
rowIdx
<
numRows
)
&&
(
elem_idx
<
numCols
);
// Same as inOffset because 8 elements are packed into one uint32_t.
if
constexpr
(
CVT_FP4_PACK16
)
{
int64_t
outOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
ld256_or_zero_cg_u32
<
Type
>
(
auto
&
out_pos
=
out
[
outOffset
];
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
ld256_or_zero_cg_u32
<
Type
>
(
in_vec2
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset2
*
8
],
valid
);
}
else
{
ld128_or_zero_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
ld128_or_zero_cg_u32
<
Type
>
(
in_vec2
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset2
*
4
],
valid
);
}
// Compute silu and mul
// Compute silu and mul
PackedVec
out_silu_mul
=
compute_silu_mul
(
in_vec
,
in_vec2
);
PackedVec
out_silu_mul
=
compute_silu_mul
<
Type
>
(
in_vec
,
in_vec2
);
auto
sf_out
=
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx
,
colIdx
,
numKTiles
,
SFout
);
rowIdx
,
colIdx
,
numKTiles
,
SFout
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
out_silu_mul
,
SFScaleVal
,
auto
out_val
=
sf_out
);
cvt_warp_fp16_to_fp4
<
Type
,
CVT_FP4_NUM_THREADS_PER_SF
,
UE8M0_SF
>
(
out_silu_mul
,
SFScaleVal
,
sf_out
);
if
(
valid
)
{
if
constexpr
(
CVT_FP4_PACK16
)
{
int64_t
outOffset
=
rowIdx
*
(
numCols
/
8
)
+
colIdx
*
2
;
uint64_t
packed64
=
(
uint64_t
(
out_val
.
hi
)
<<
32
)
|
uint64_t
(
out_val
.
lo
);
reinterpret_cast
<
uint64_t
*>
(
out
)[
outOffset
>>
1
]
=
packed64
;
}
else
{
out
[
inOffset
]
=
out_val
;
}
}
}
}
}
}
}
}
...
@@ -103,17 +136,23 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
...
@@ -103,17 +136,23 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto
output_ptr
=
static_cast
<
int64_t
*>
(
output
.
data_ptr
());
auto
output_ptr
=
static_cast
<
int64_t
*>
(
output
.
data_ptr
());
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
dim3
block
(
std
::
min
(
int
(
n
/
ELTS_PER_THREAD
),
1024
));
dim3
block
(
std
::
min
(
int
(
n
/
ELTS_PER_THREAD
),
512
));
int
const
numBlocksPerSM
=
int
const
numBlocksPerSM
=
vllm_runtime_blocks_per_sm
(
static_cast
<
int
>
(
block
.
x
));
vllm_runtime_blocks_per_sm
(
static_cast
<
int
>
(
block
.
x
));
dim3
grid
(
std
::
min
(
int
(
m
),
multiProcessorCount
*
numBlocksPerSM
));
int
sf_n_unpadded
=
int
(
n
/
CVT_FP4_SF_VEC_SIZE
);
int
grid_y
=
vllm
::
div_round_up
(
sf_n_unpadded
,
static_cast
<
int
>
(
block
.
x
));
int
grid_x
=
std
::
min
(
int
(
m
),
std
::
max
(
1
,
(
multiProcessorCount
*
numBlocksPerSM
)
/
grid_y
));
dim3
grid
(
grid_x
,
grid_y
);
VLLM_DISPATCH_HALF_TYPES
(
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"silu_and_mul_nvfp4_quant_kernel"
,
[
&
]
{
input
.
scalar_type
(),
"silu_and_mul_nvfp4_quant_kernel"
,
[
&
]
{
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
vllm
::
silu_mul_cvt_fp16_to_fp4
<
cuda_type
><<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
silu_mul_cvt_fp16_to_fp4
<
cuda_type
><<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
m
,
n
,
sf_n_unpadded
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
});
...
...
csrc/quantization/fp4/nvfp4_experts_quant.cu
View file @
fcb9df99
...
@@ -140,8 +140,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
...
@@ -140,8 +140,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
CVT_FP4_NUM_THREADS_PER_SF
>
(
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numKTiles
,
SFout_in_expert
);
rowIdx_in_expert
,
colIdx
,
numKTiles
,
SFout_in_expert
);
out_pos
=
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
CVT_FP4_NUM_THREADS_PER_SF
,
UE8M0_SF
>
(
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
quant_input
,
SFScaleVal
,
sf_out
);
quant_input
,
SFScaleVal
,
sf_out
);
}
}
}
}
...
@@ -246,8 +246,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
...
@@ -246,8 +246,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
CVT_FP4_NUM_THREADS_PER_SF
>
(
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
numKTiles
,
SFout_in_expert
);
rowIdx_in_expert
,
colIdx
,
numKTiles
,
SFout_in_expert
);
out_pos
=
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
CVT_FP4_NUM_THREADS_PER_SF
,
UE8M0_SF
>
(
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
quant_input
,
SFScaleVal
,
sf_out
);
quant_input
,
SFScaleVal
,
sf_out
);
}
}
}
}
...
...
csrc/quantization/fp4/nvfp4_quant_entry.cu
View file @
fcb9df99
...
@@ -21,7 +21,8 @@
...
@@ -21,7 +21,8 @@
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
output_sf
,
torch
::
Tensor
const
&
output_sf
,
torch
::
Tensor
const
&
input_sf
);
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
);
#endif
#endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
...
@@ -51,10 +52,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
...
@@ -51,10 +52,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
#endif
#endif
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
torch
::
Tensor
&
output_sf
,
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
)
{
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return
scaled_fp4_quant_sm1xxa
(
output
,
input
,
output_sf
,
input_sf
);
return
scaled_fp4_quant_sm1xxa
(
output
,
input
,
output_sf
,
input_sf
,
is_sf_swizzled_layout
);
#endif
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled nvfp4 quantization kernel"
);
}
}
...
...
csrc/quantization/fp4/nvfp4_quant_kernels.cu
View file @
fcb9df99
...
@@ -27,29 +27,23 @@
...
@@ -27,29 +27,23 @@
#include "cuda_utils.h"
#include "cuda_utils.h"
#include "launch_bounds_utils.h"
#include "launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
#include "nvfp4_utils.cuh"
#include "nvfp4_utils.cuh"
namespace
vllm
{
namespace
vllm
{
template
<
typename
Int
>
__host__
__device__
inline
Int
round_up
(
Int
x
,
Int
y
)
{
static_assert
(
std
::
is_integral_v
<
Int
>
,
"round_up argument must be integral type"
);
return
((
x
+
y
-
1
)
/
y
)
*
y
;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline
int
computeEffectiveRows
(
int
m
)
{
constexpr
int
ROW_TILE
=
128
;
return
round_up
(
m
,
ROW_TILE
);
}
// Use UE4M3 by default.
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
__launch_bounds__
(
512
,
VLLM_BLOCKS_PER_SM
(
512
))
__global__
void
__launch_bounds__
(
512
,
VLLM_BLOCKS_PER_SM
(
512
))
cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
Type
const
*
in
,
cvt_fp16_to_fp4
(
int32_t
numRows
,
int32_t
numCols
,
int32_t
num_padded_cols
,
float
const
*
SFScale
,
uint32_t
*
out
,
uint32_t
*
SFout
)
{
Type
const
*
__restrict__
in
,
using
PackedVec
=
PackedVec
<
Type
>
;
float
const
*
__restrict__
SFScale
,
uint32_t
*
__restrict__
out
,
uint32_t
*
__restrict__
SFout
)
{
using
PackedVec
=
vllm
::
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
...
@@ -59,33 +53,31 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
...
@@ -59,33 +53,31 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int32_t
const
numKTiles
=
(
numCols
+
63
)
/
64
;
int32_t
const
numKTiles
=
(
numCols
+
63
)
/
64
;
int
sf_m
=
round_up
<
int
>
(
numRows
,
128
);
int
sf_m
=
round_up
<
int
>
(
numRows
,
128
);
int
sf_n_unpadded
=
numCols
/
CVT_FP4_SF_VEC_SIZE
;
int32_t
const
colIdx
=
blockDim
.
x
*
blockIdx
.
y
+
threadIdx
.
x
;
int
sf_n_int
=
round_up
<
int
>
(
sf_n_unpadded
,
4
)
/
4
;
int
elem_idx
=
colIdx
*
CVT_FP4_ELTS_PER_THREAD
;
int
num_padded_cols
=
sf_n_int
*
4
*
CVT_FP4_SF_VEC_SIZE
;
// Get the global scaling factor, which will be applied to the SF.
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
// (448.f / (Alpha_A / 6.f)).
float
const
global_scale
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
0
];
float
const
global_scale
=
(
SFScale
==
nullptr
)
?
1.0
f
:
SFScale
[
0
];
// Iterate over all rows and cols including padded ones -
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
// ensures we visit every single scale factor address to initialize it.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
sf_m
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
sf_m
;
rowIdx
+=
gridDim
.
x
)
{
for
(
int
colIdx
=
threadIdx
.
x
;
if
(
colIdx
<
num_padded_cols
)
{
colIdx
<
num_padded_cols
/
CVT_FP4_ELTS_PER_THREAD
;
colIdx
+=
blockDim
.
x
)
{
int
elem_idx
=
colIdx
*
CVT_FP4_ELTS_PER_THREAD
;
PackedVec
in_vec
;
PackedVec
in_vec
;
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
// If we are outside valid rows OR outside valid columns -> Use Zeros
// If we are outside valid rows OR outside valid columns -> Use Zeros
if
(
rowIdx
>=
numRows
||
elem_idx
>=
numCols
)
{
bool
valid
=
(
rowIdx
<
numRows
)
&&
(
elem_idx
<
numCols
);
memset
(
&
in_vec
,
0
,
sizeof
(
PackedVec
));
if
constexpr
(
CVT_FP4_PACK16
)
{
ld256_or_zero_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
}
else
{
}
else
{
// Valid Region: Load actual data
ld128_or_zero_cg_u32
<
Type
>
(
in_vec
=
reinterpret_cast
<
PackedVec
const
*>
(
in
)[
inOffset
];
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
}
}
auto
sf_out
=
auto
sf_out
=
...
@@ -94,16 +86,88 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
...
@@ -94,16 +86,88 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
rowIdx
,
colIdx
,
numKTiles
,
SFout
);
rowIdx
,
colIdx
,
numKTiles
,
SFout
);
auto
out_val
=
auto
out_val
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
global_scale
,
sf_out
);
cvt_warp_fp16_to_fp4
<
Type
,
CVT_FP4_NUM_THREADS_PER_SF
,
UE8M0_SF
>
(
in_vec
,
global_scale
,
sf_out
);
// We do NOT write output for padding because the 'out' tensor is not
// padded.
if
(
valid
)
{
if
constexpr
(
CVT_FP4_PACK16
)
{
int64_t
outOffset
=
rowIdx
*
(
numCols
/
8
)
+
colIdx
*
2
;
uint64_t
packed64
=
(
uint64_t
(
out_val
.
hi
)
<<
32
)
|
uint64_t
(
out_val
.
lo
);
reinterpret_cast
<
uint64_t
*>
(
out
)[
outOffset
>>
1
]
=
packed64
;
}
else
{
out
[
inOffset
]
=
out_val
;
}
}
}
}
}
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
__launch_bounds__
(
512
,
VLLM_BLOCKS_PER_SM
(
512
))
cvt_fp16_to_fp4_sf_major
(
int32_t
numRows
,
int32_t
numCols
,
int32_t
sf_n_unpadded
,
Type
const
*
__restrict__
in
,
float
const
*
__restrict__
SFScale
,
uint32_t
*
__restrict__
out
,
uint32_t
*
__restrict__
SFout
)
{
using
PackedVec
=
PackedVec
<
Type
>
;
static
constexpr
int
CVT_FP4_NUM_THREADS_PER_SF
=
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
int32_t
const
colIdx
=
blockDim
.
x
*
blockIdx
.
y
+
threadIdx
.
x
;
int
elem_idx
=
colIdx
*
CVT_FP4_ELTS_PER_THREAD
;
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float
const
global_scale
=
(
SFScale
==
nullptr
)
?
1.0
f
:
SFScale
[
0
];
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for
(
int
rowIdx
=
blockIdx
.
x
;
rowIdx
<
numRows
;
rowIdx
+=
gridDim
.
x
)
{
if
(
colIdx
<
sf_n_unpadded
)
{
PackedVec
in_vec
;
int64_t
inOffset
=
rowIdx
*
(
numCols
/
CVT_FP4_ELTS_PER_THREAD
)
+
colIdx
;
// If we are outside valid rows OR outside valid columns -> Use Zeros
bool
valid
=
(
rowIdx
<
numRows
)
&&
(
elem_idx
<
numCols
);
if
constexpr
(
CVT_FP4_PACK16
)
{
ld256_or_zero_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
8
],
valid
);
}
else
{
ld128_or_zero_cg_u32
<
Type
>
(
in_vec
,
&
reinterpret_cast
<
const
uint32_t
*>
(
in
)[
inOffset
*
4
],
valid
);
}
auto
sf_out
=
sf_out_rowmajor_u8
<
uint32_t
>
(
rowIdx
,
colIdx
,
sf_n_unpadded
,
SFout
);
auto
out_val
=
cvt_warp_fp16_to_fp4
<
Type
,
CVT_FP4_NUM_THREADS_PER_SF
,
UE8M0_SF
>
(
in_vec
,
global_scale
,
sf_out
);
// We do NOT write output for padding because the 'out' tensor is not
// We do NOT write output for padding because the 'out' tensor is not
// padded.
// padded.
if
(
rowIdx
<
numRows
&&
elem_idx
<
numCols
)
{
if
(
valid
)
{
// Same as inOffset because 8 elements are packed into one uint32_t.
if
constexpr
(
CVT_FP4_PACK16
)
{
int64_t
outOffset
=
rowIdx
*
(
numCols
/
8
)
+
colIdx
*
2
;
uint64_t
packed64
=
(
uint64_t
(
out_val
.
hi
)
<<
32
)
|
uint64_t
(
out_val
.
lo
);
reinterpret_cast
<
uint64_t
*>
(
out
)[
outOffset
>>
1
]
=
packed64
;
}
else
{
out
[
inOffset
]
=
out_val
;
out
[
inOffset
]
=
out_val
;
}
}
}
}
}
}
}
}
}
}
// namespace vllm
}
// namespace vllm
...
@@ -111,7 +175,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
...
@@ -111,7 +175,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
output_sf
,
torch
::
Tensor
const
&
output_sf
,
torch
::
Tensor
const
&
input_sf
)
{
torch
::
Tensor
const
&
input_sf
,
bool
is_sf_swizzled_layout
)
{
int32_t
m
=
input
.
size
(
0
);
int32_t
m
=
input
.
size
(
0
);
int32_t
n
=
input
.
size
(
1
);
int32_t
n
=
input
.
size
(
1
);
...
@@ -129,19 +194,48 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
...
@@ -129,19 +194,48 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
int
sf_n_unpadded
=
int
(
n
/
CVT_FP4_SF_VEC_SIZE
);
// Grid, Block size. Each thread converts 8 values.
// Grid, Block size. Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
n
/
ELTS_PER_THREAD
),
512
));
dim3
block
(
std
::
min
(
int
(
n
/
ELTS_PER_THREAD
),
512
));
int
const
numBlocksPerSM
=
int
const
numBlocksPerSM
=
vllm_runtime_blocks_per_sm
(
static_cast
<
int
>
(
block
.
x
));
vllm_runtime_blocks_per_sm
(
static_cast
<
int
>
(
block
.
x
));
int
effectiveRows
=
vllm
::
computeEffectiveRows
(
m
);
dim3
grid
(
std
::
min
(
effectiveRows
,
multiProcessorCount
*
numBlocksPerSM
));
if
(
is_sf_swizzled_layout
)
{
int
sf_n_int
=
int
(
vllm
::
round_up
(
sf_n_unpadded
,
4
)
/
4
);
int32_t
num_padded_cols
=
sf_n_int
*
4
*
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
;
int
grid_y
=
vllm
::
div_round_up
(
num_padded_cols
,
static_cast
<
int
>
(
block
.
x
));
int
grid_x
=
std
::
min
(
vllm
::
computeEffectiveRows
(
m
),
std
::
max
(
1
,
(
multiProcessorCount
*
numBlocksPerSM
)
/
grid_y
));
dim3
grid
(
grid_x
,
grid_y
);
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"nvfp4_quant_kernel"
,
[
&
]
{
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"nvfp4_quant_kernel"
,
[
&
]
{
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
// NOTE: We don't support e8m0 scales at this moment.
// NOTE: We don't support e8m0 scales at this moment.
vllm
::
cvt_fp16_to_fp4
<
cuda_type
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
cvt_fp16_to_fp4
<
cuda_type
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
m
,
n
,
num_padded_cols
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
}
else
{
int
grid_y
=
vllm
::
div_round_up
(
sf_n_unpadded
,
static_cast
<
int
>
(
block
.
x
));
int
grid_x
=
std
::
min
(
m
,
std
::
max
(
1
,
(
multiProcessorCount
*
numBlocksPerSM
)
/
grid_y
));
dim3
grid
(
grid_x
,
grid_y
);
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"nvfp4_quant_kernel"
,
[
&
]
{
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
// NOTE: We don't support e8m0 scales at this moment.
vllm
::
cvt_fp16_to_fp4_sf_major
<
cuda_type
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
sf_n_unpadded
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
});
}
}
}
csrc/quantization/fp4/nvfp4_utils.cuh
View file @
fcb9df99
...
@@ -19,9 +19,17 @@
...
@@ -19,9 +19,17 @@
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.h>
#define ELTS_PER_THREAD 8
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
#define ELTS_PER_THREAD 16
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
16
;
constexpr
bool
CVT_FP4_PACK16
=
true
;
#else
#define ELTS_PER_THREAD 8
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
int
CVT_FP4_ELTS_PER_THREAD
=
8
;
constexpr
bool
CVT_FP4_PACK16
=
false
;
#endif
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
constexpr
int
CVT_FP4_SF_VEC_SIZE
=
16
;
namespace
vllm
{
namespace
vllm
{
...
@@ -68,19 +76,46 @@ struct TypeConverter<__nv_bfloat16> {
...
@@ -68,19 +76,46 @@ struct TypeConverter<__nv_bfloat16> {
using
Type
=
__nv_bfloat162
;
using
Type
=
__nv_bfloat162
;
};
};
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
// Define a 32 bytes packed data type.
template
<
class
Type
>
struct
alignas
(
32
)
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
8
];
};
#else
// Define a 16 bytes packed data type.
// Define a 16 bytes packed data type.
template
<
class
Type
>
template
<
class
Type
>
struct
PackedVec
{
struct
alignas
(
16
)
PackedVec
{
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
typename
TypeConverter
<
Type
>::
Type
elts
[
4
];
};
};
#endif
template
<
>
template
<
>
struct
PackedVec
<
__nv_fp8_e4m3
>
{
struct
PackedVec
<
__nv_fp8_e4m3
>
{
__nv_fp8x2_e4m3
elts
[
8
];
__nv_fp8x2_e4m3
elts
[
8
];
};
};
template
<
typename
Int
>
__host__
__device__
inline
Int
round_up
(
Int
x
,
Int
y
)
{
static_assert
(
std
::
is_integral_v
<
Int
>
,
"round_up argument must be integral type"
);
return
((
x
+
y
-
1
)
/
y
)
*
y
;
}
template
<
typename
Int
>
__host__
__device__
__forceinline__
Int
div_round_up
(
Int
x
,
Int
y
)
{
return
(
x
+
y
-
1
)
/
y
;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline
int
computeEffectiveRows
(
int
m
)
{
constexpr
int
ROW_TILE
=
128
;
return
round_up
(
m
,
ROW_TILE
);
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float
(
&
array
)[
8
])
{
inline
__device__
uint32_t
fp32_vec
8
_to_e2m1
(
float
(
&
array
)[
8
])
{
uint32_t
val
;
uint32_t
val
;
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
...
@@ -101,7 +136,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
...
@@ -101,7 +136,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
}
}
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline
__device__
uint32_t
fp32_vec_to_e2m1
(
float2
(
&
array
)[
4
])
{
__device__
__forceinline__
uint32_t
fp32_vec
8
_to_e2m1
(
float2
(
&
array
)[
4
])
{
uint32_t
val
;
uint32_t
val
;
asm
volatile
(
asm
volatile
(
"{
\n
"
"{
\n
"
...
@@ -114,20 +149,115 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
...
@@ -114,20 +149,115 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"mov.b32 %0, {byte0, byte1, byte2, byte3};
\n
"
"}"
"}
\n
"
:
"=r"
(
val
)
:
"=r"
(
val
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
));
return
val
;
return
val
;
}
}
struct
u32x2
{
uint32_t
lo
,
hi
;
};
using
fp4_packed_t
=
std
::
conditional_t
<
CVT_FP4_PACK16
,
u32x2
,
uint32_t
>
;
__device__
__forceinline__
u32x2
fp32_vec16_to_e2m1
(
float2
(
&
array
)[
8
])
{
u32x2
out
;
asm
volatile
(
"{
\n
"
".reg .b8 b0;
\n
"
".reg .b8 b1;
\n
"
".reg .b8 b2;
\n
"
".reg .b8 b3;
\n
"
".reg .b8 b4;
\n
"
".reg .b8 b5;
\n
"
".reg .b8 b6;
\n
"
".reg .b8 b7;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;
\n
"
"cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;
\n
"
"mov.b32 %0, {b0, b1, b2, b3};
\n
"
"mov.b32 %1, {b4, b5, b6, b7};
\n
"
"}
\n
"
:
"=r"
(
out
.
lo
),
"=r"
(
out
.
hi
)
:
"f"
(
array
[
0
].
x
),
"f"
(
array
[
0
].
y
),
"f"
(
array
[
1
].
x
),
"f"
(
array
[
1
].
y
),
"f"
(
array
[
2
].
x
),
"f"
(
array
[
2
].
y
),
"f"
(
array
[
3
].
x
),
"f"
(
array
[
3
].
y
),
"f"
(
array
[
4
].
x
),
"f"
(
array
[
4
].
y
),
"f"
(
array
[
5
].
x
),
"f"
(
array
[
5
].
y
),
"f"
(
array
[
6
].
x
),
"f"
(
array
[
6
].
y
),
"f"
(
array
[
7
].
x
),
"f"
(
array
[
7
].
y
));
return
out
;
}
__device__
__forceinline__
uint32_t
pack_fp4
(
float2
(
&
v
)[
4
])
{
return
fp32_vec8_to_e2m1
(
v
);
}
__device__
__forceinline__
u32x2
pack_fp4
(
float2
(
&
v
)[
8
])
{
return
fp32_vec16_to_e2m1
(
v
);
}
// Fast reciprocal.
// Fast reciprocal.
inline
__device__
float
reciprocal_approximate_ftz
(
float
a
)
{
__devic
e__
__forceinlin
e__
float
reciprocal_approximate_ftz
(
float
a
)
{
float
b
;
float
b
;
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;
\n
"
:
"=f"
(
b
)
:
"f"
(
a
));
asm
volatile
(
"rcp.approx.ftz.f32 %0, %1;"
:
"=f"
(
b
)
:
"f"
(
a
));
return
b
;
return
b
;
}
}
template
<
class
Type
>
__device__
__forceinline__
void
ld128_or_zero_cg_u32
(
PackedVec
<
Type
>&
out
,
const
void
*
ptr
,
bool
pred
)
{
uint32_t
r0
,
r1
,
r2
,
r3
;
asm
volatile
(
"{
\n
"
" .reg .pred pr;
\n
"
" setp.ne.u32 pr, %4, 0;
\n
"
" mov.u32 %0, 0;
\n
"
" mov.u32 %1, 0;
\n
"
" mov.u32 %2, 0;
\n
"
" mov.u32 %3, 0;
\n
"
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];
\n
"
"}
\n
"
:
"=r"
(
r0
),
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
)
:
"r"
((
int
)
pred
),
"l"
(
ptr
));
*
reinterpret_cast
<
uint4
*>
(
&
out
)
=
uint4
{
r0
,
r1
,
r2
,
r3
};
}
template
<
class
Type
>
__device__
__forceinline__
void
ld256_or_zero_cg_u32
(
PackedVec
<
Type
>&
out
,
const
void
*
ptr
,
bool
pred
)
{
uint32_t
r0
,
r1
,
r2
,
r3
,
r4
,
r5
,
r6
,
r7
;
asm
volatile
(
"{
\n
"
" .reg .pred pr;
\n
"
" setp.ne.u32 pr, %8, 0;
\n
"
" mov.u32 %0, 0;
\n
"
" mov.u32 %1, 0;
\n
"
" mov.u32 %2, 0;
\n
"
" mov.u32 %3, 0;
\n
"
" mov.u32 %4, 0;
\n
"
" mov.u32 %5, 0;
\n
"
" mov.u32 %6, 0;
\n
"
" mov.u32 %7, 0;
\n
"
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];
\n
"
"}
\n
"
:
"=r"
(
r0
),
"=r"
(
r1
),
"=r"
(
r2
),
"=r"
(
r3
),
"=r"
(
r4
),
"=r"
(
r5
),
"=r"
(
r6
),
"=r"
(
r7
)
:
"r"
((
int
)
pred
),
"l"
(
ptr
));
reinterpret_cast
<
uint4
*>
(
&
out
)[
0
]
=
uint4
{
r0
,
r1
,
r2
,
r3
};
reinterpret_cast
<
uint4
*>
(
&
out
)[
1
]
=
uint4
{
r4
,
r5
,
r6
,
r7
};
}
// Compute SF output offset for swizzled tensor core layout.
// Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64
// Caller must precompute: numKTiles = (numCols + 63) / 64
...
@@ -166,21 +296,41 @@ __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
...
@@ -166,21 +296,41 @@ __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
}
template
<
class
SFType
>
__device__
__forceinline__
uint8_t
*
sf_out_rowmajor_u8
(
int
row
,
int
pack
,
int
packs_per_row_sf
,
SFType
*
SFout
)
{
constexpr
int
PACK
=
CVT_FP4_ELTS_PER_THREAD
;
constexpr
int
THREADS_PER_SF
=
CVT_FP4_SF_VEC_SIZE
/
PACK
;
// 1 if PACK=16, 2 else PACK=8
if
(
threadIdx
.
x
%
THREADS_PER_SF
!=
0
)
return
nullptr
;
int
sf_col
=
pack
/
THREADS_PER_SF
;
// PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2
int64_t
off
=
(
int64_t
)
row
*
packs_per_row_sf
+
sf_col
;
return
(
uint8_t
*
)
SFout
+
off
;
}
// Quantizes the provided PackedVec into the uint32_t output
// Quantizes the provided PackedVec into the uint32_t output
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
template
<
class
Type
,
int
CVT_FP4_NUM_THREADS_PER_SF
,
bool
UE8M0_SF
=
false
>
__device__
uint32_t
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
__device__
__forceinline__
fp4_packed_t
uint8_t
*
SFout
)
{
cvt_warp_fp16_to_fp4
(
PackedVec
<
Type
>&
vec
,
float
SFScaleVal
,
uint8_t
*
SFout
)
{
// Get absolute maximum values among the local 8 values.
// Get absolute maximum values among the local 8 values.
auto
localMax
=
__habs2
(
vec
.
elts
[
0
]);
auto
localMax
=
__habs2
(
vec
.
elts
[
0
]);
// Local maximum value.
// Local maximum value.
#pragma unroll
#pragma unroll
for
(
int
i
=
1
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
for
(
int
i
=
1
;
i
<
CVT_FP4_ELTS_PER_THREAD
/
2
;
i
++
)
{
localMax
=
__hmax2
(
localMax
,
__habs2
(
vec
.
elts
[
i
]));
localMax
=
__hmax2
(
localMax
,
__habs2
(
vec
.
elts
[
i
]));
}
}
// Get the absolute maximum among all 16 values (two threads).
// Get the absolute maximum among all 16 values (two threads).
localMax
=
__hmax2
(
__shfl_xor_sync
(
uint32_t
(
-
1
),
localMax
,
1
),
localMax
);
if
constexpr
(
CVT_FP4_NUM_THREADS_PER_SF
==
2
)
{
localMax
=
__hmax2
(
__shfl_xor_sync
(
0xffffffffu
,
localMax
,
1
),
localMax
);
}
// Get the final absolute maximum values.
// Get the final absolute maximum values.
float
vecMax
=
float
(
__hmax
(
localMax
.
x
,
localMax
.
y
));
float
vecMax
=
float
(
__hmax
(
localMax
.
x
,
localMax
.
y
));
...
@@ -205,19 +355,18 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
...
@@ -205,19 +355,18 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
// Convert back to fp32.
// Convert back to fp32.
SFValue
=
float
(
tmp
);
SFValue
=
float
(
tmp
);
}
}
// Write the SF to global memory (STG.8).
if
(
SFout
)
*
SFout
=
fp8SFVal
;
// Get the output scale.
// Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal))
// reciprocal(SFScaleVal))
float
outputScale
=
float
outputScale
=
SFValue
!=
0
?
reciprocal_approximate_ftz
(
SFValue
!=
0
.0
f
?
reciprocal_approximate_ftz
(
SFValue
*
reciprocal_approximate_ftz
(
SFScaleVal
))
SFValue
*
reciprocal_approximate_ftz
(
SFScaleVal
))
:
0.0
f
;
:
0.0
f
;
if
(
SFout
)
{
// Write the SF to global memory (STG.8).
*
SFout
=
fp8SFVal
;
}
// Convert the input to float.
// Convert the input to float.
float2
fp2Vals
[
CVT_FP4_ELTS_PER_THREAD
/
2
];
float2
fp2Vals
[
CVT_FP4_ELTS_PER_THREAD
/
2
];
...
@@ -233,10 +382,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
...
@@ -233,10 +382,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
}
}
// Convert to e2m1 values.
// Convert to e2m1 values.
uint32_t
e2m1Vec
=
fp32_vec_to_e2m1
(
fp2Vals
);
return
pack_fp4
(
fp2Vals
);
// Write the e2m1 values to global memory.
return
e2m1Vec
;
}
}
// silu in float32
// silu in float32
...
...
csrc/torch_bindings.cpp
View file @
fcb9df99
...
@@ -546,7 +546,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -546,7 +546,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute NVFP4 block quantized tensor.
// Compute NVFP4 block quantized tensor.
ops
.
def
(
ops
.
def
(
"scaled_fp4_quant(Tensor! output, Tensor input,"
"scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale) -> ()"
);
" Tensor! output_scale, Tensor input_scale, bool "
"is_sf_swizzled_layout) -> ()"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
// Compute NVFP4 experts quantization.
// Compute NVFP4 experts quantization.
...
...
tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
View file @
fcb9df99
...
@@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm(
...
@@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm(
# from checkpoints are in linear scales.
# from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py,
# So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here.
# we need to unswizzle for trtllm here.
a_fp4
,
a_scale_interleaved
=
ops
.
scaled_fp4_quant
(
a_dtype
,
a_global_scale
,
backend
)
a_fp4
,
a_scale_interleaved
=
ops
.
scaled_fp4_quant
(
a_dtype
,
a_global_scale
,
is_sf_swizzled_layout
=
True
,
backend
=
backend
)
is_sf_128x4_layout
=
not
(
backend
==
"trtllm"
and
m
<=
32
)
is_sf_128x4_layout
=
not
(
backend
==
"trtllm"
and
m
<=
32
)
b_fp4
,
b_scale_interleaved
=
ops
.
scaled_fp4_quant
(
b_dtype
,
b_global_scale
)
b_fp4
,
b_scale_interleaved
=
ops
.
scaled_fp4_quant
(
b_dtype
,
b_global_scale
,
is_sf_swizzled_layout
=
True
)
# get_ref_results unswizzles the scales internally.
# get_ref_results unswizzles the scales internally.
expected_out
=
get_ref_results
(
expected_out
=
get_ref_results
(
...
...
tests/kernels/quantization/test_nvfp4_quant.py
View file @
fcb9df99
...
@@ -27,6 +27,12 @@ PAD_SHAPES = [
...
@@ -27,6 +27,12 @@ PAD_SHAPES = [
(
150
,
128
),
(
150
,
128
),
(
150
,
48
),
(
150
,
48
),
(
90
,
80
),
(
90
,
80
),
(
128
,
512
),
(
128
,
1024
),
(
128
,
2048
),
(
64
,
7168
),
(
64
,
7152
),
(
32
,
14336
),
]
]
SEEDS
=
[
42
]
SEEDS
=
[
42
]
CUDA_DEVICES
=
[
"cuda:0"
]
CUDA_DEVICES
=
[
"cuda:0"
]
...
@@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
...
@@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
out_ans
=
cast_from_fp4
(
out
,
m
,
n
)
out_ans
=
cast_from_fp4
(
out
,
m
,
n
)
torch
.
testing
.
assert_close
(
out_ans
,
out_ref
)
torch
.
testing
.
assert_close
(
out_ans
,
out_ref
)
torch
.
testing
.
assert_close
(
scale_ans
,
scale_ref
)
torch
.
testing
.
assert_close
(
scale_ans
,
scale_ref
)
@
pytest
.
mark
.
parametrize
(
"pad_shape"
,
PAD_SHAPES
)
@
torch
.
inference_mode
()
def
test_quantize_to_fp4_padded_no_sf_swizzled
(
pad_shape
:
tuple
[
int
,
int
])
->
None
:
dtype
=
torch
.
float16
set_random_seed
(
42
)
torch
.
set_default_device
(
"cuda:0"
)
m
,
n
=
pad_shape
x
=
torch
.
randn
((
m
,
n
),
dtype
=
dtype
)
tensor_amax
=
torch
.
abs
(
x
).
max
().
to
(
torch
.
float32
)
global_scale
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
tensor_amax
out_ref
,
scale_ref
=
ref_nvfp4_quant
(
x
,
global_scale
)
out
,
out_scale
=
ops
.
scaled_fp4_quant
(
x
,
global_scale
,
is_sf_swizzled_layout
=
False
)
scale_ans
=
out_scale
.
to
(
torch
.
float32
)
out_ans
=
cast_from_fp4
(
out
,
m
,
n
)
torch
.
testing
.
assert_close
(
out_ans
,
out_ref
)
torch
.
testing
.
assert_close
(
scale_ans
,
scale_ref
)
vllm/_custom_ops.py
View file @
fcb9df99
...
@@ -1534,6 +1534,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
...
@@ -1534,6 +1534,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
def
scaled_fp4_quant
(
def
scaled_fp4_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
,
is_sf_swizzled_layout
:
bool
=
True
,
backend
:
str
=
"none"
,
backend
:
str
=
"none"
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
...
@@ -1577,7 +1578,7 @@ def scaled_fp4_quant(
...
@@ -1577,7 +1578,7 @@ def scaled_fp4_quant(
else
:
else
:
# Two fp4 values will be packed into an uint8.
# Two fp4 values will be packed into an uint8.
output
=
torch
.
empty
((
m
,
n
//
2
),
device
=
device
,
dtype
=
torch
.
uint8
)
output
=
torch
.
empty
((
m
,
n
//
2
),
device
=
device
,
dtype
=
torch
.
uint8
)
if
is_sf_swizzled_layout
:
# We use the rounded values to store the swizzled values. Due to the
# We use the rounded values to store the swizzled values. Due to the
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
...
@@ -1590,8 +1591,12 @@ def scaled_fp4_quant(
...
@@ -1590,8 +1591,12 @@ def scaled_fp4_quant(
output_scale
=
torch
.
empty
(
output_scale
=
torch
.
empty
(
(
rounded_m
,
rounded_n
//
4
),
device
=
device
,
dtype
=
torch
.
int32
(
rounded_m
,
rounded_n
//
4
),
device
=
device
,
dtype
=
torch
.
int32
)
)
else
:
output_scale
=
torch
.
empty
((
m
,
n
//
16
),
device
=
device
,
dtype
=
torch
.
uint8
)
torch
.
ops
.
_C
.
scaled_fp4_quant
(
output
,
input
,
output_scale
,
input_global_scale
)
torch
.
ops
.
_C
.
scaled_fp4_quant
(
output
,
input
,
output_scale
,
input_global_scale
,
is_sf_swizzled_layout
)
output_scale
=
output_scale
.
view
(
torch
.
float8_e4m3fn
)
output_scale
=
output_scale
.
view
(
torch
.
float8_e4m3fn
)
return
output
,
output_scale
return
output
,
output_scale
...
...
vllm/compilation/activation_quant_fusion.py
View file @
fcb9df99
...
@@ -152,6 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
...
@@ -152,6 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
input
=
result_silu_mul
,
input
=
result_silu_mul
,
output_scale
=
output_scale
,
output_scale
=
output_scale
,
input_scale
=
scale
,
input_scale
=
scale
,
is_sf_swizzled_layout
=
True
,
)
)
return
at
[
1
],
at
[
2
]
return
at
[
1
],
at
[
2
]
...
...
vllm/compilation/collective_fusion.py
View file @
fcb9df99
...
@@ -946,6 +946,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -946,6 +946,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
input
=
rms
,
input
=
rms
,
output_scale
=
output_scale
,
output_scale
=
output_scale
,
input_scale
=
input_global_scale
,
input_scale
=
input_global_scale
,
is_sf_swizzled_layout
=
True
,
)
)
# quant_out, allreduce_output, output_scale
# quant_out, allreduce_output, output_scale
...
@@ -1043,6 +1044,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
...
@@ -1043,6 +1044,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
input
=
rms
,
input
=
rms
,
output_scale
=
output_scale
,
output_scale
=
output_scale
,
input_scale
=
input_global_scale
,
input_scale
=
input_global_scale
,
is_sf_swizzled_layout
=
True
,
)
)
# quant_out, allreduce_output, output_scale
# quant_out, allreduce_output, output_scale
...
...
vllm/compilation/fusion_attn.py
View file @
fcb9df99
...
@@ -248,6 +248,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
...
@@ -248,6 +248,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
input
=
attn_out_view
,
input
=
attn_out_view
,
output_scale
=
output_scale
,
output_scale
=
output_scale
,
input_scale
=
input_scale
,
input_scale
=
input_scale
,
is_sf_swizzled_layout
=
True
,
)
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
output_scale_view
=
torch
.
ops
.
aten
.
view
.
dtype
(
at2
[
2
],
FP8_DTYPE
)
return
at2
[
1
],
output_scale_view
return
at2
[
1
],
output_scale_view
...
...
vllm/model_executor/layers/fused_moe/utils.py
View file @
fcb9df99
...
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
...
@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize
,
mxfp8_e4m3_quantize
,
)
)
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.flashinfer
import
flashinfer_fp4_quantize
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
@@ -117,9 +116,7 @@ def _nvfp4_quantize(
...
@@ -117,9 +116,7 @@ def _nvfp4_quantize(
A_scale
:
torch
.
Tensor
|
None
,
A_scale
:
torch
.
Tensor
|
None
,
is_sf_swizzled_layout
:
bool
,
is_sf_swizzled_layout
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
flashinfer_fp4_quantize
(
return
ops
.
scaled_fp4_quant
(
A
,
A_scale
,
is_sf_swizzled_layout
=
is_sf_swizzled_layout
)
A
,
A_scale
,
is_sf_swizzled_layout
=
is_sf_swizzled_layout
)
def
_fp8_quantize
(
def
_fp8_quantize
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py
View file @
fcb9df99
...
@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
...
@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_global_scale
,
self
.
backend
x
,
layer
.
input_global_scale
,
is_sf_swizzled_layout
=
True
,
backend
=
self
.
backend
,
)
)
mm_args
=
(
mm_args
=
(
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
fcb9df99
...
@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
output_shape
=
[
x
.
shape
[
0
],
layer
.
weight
.
shape
[
0
]]
output_shape
=
[
x
.
shape
[
0
],
layer
.
weight
.
shape
[
0
]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_scale_inv
,
self
.
backend
)
x_fp4
,
x_blockscale
=
scaled_fp4_quant
(
x
,
layer
.
input_scale_inv
,
is_sf_swizzled_layout
=
True
,
backend
=
self
.
backend
)
# validate dtypes of quantized input, input block scale,
# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
# weight and weight_blockscale
...
...
vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
View file @
fcb9df99
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEConfig
,
...
@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe(
...
@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
=
x
hidden_states_fp4
,
hidden_states_scale_linear_fp4
=
x
else
:
else
:
# hidden_states is the already quantized
# hidden_states is the already quantized
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
flashinfer
.
fp4_quantize
(
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
ops
.
scaled_fp4_quant
(
x
,
x
,
layer
.
a1_gscale
,
is_sf_swizzled_layout
=
False
layer
.
a1_gscale
,
is_sf_swizzled_layout
=
False
,
)
)
# Determine routing method type
# Determine routing method type
...
@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe(
...
@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
=
x
hidden_states_fp4
,
hidden_states_scale_linear_fp4
=
x
else
:
else
:
# Quantize input to FP4
# Quantize input to FP4
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
flashinfer
.
fp4_quantize
(
(
hidden_states_fp4
,
hidden_states_scale_linear_fp4
)
=
ops
.
scaled_fp4_quant
(
x
,
x
,
layer
.
a1_gscale
,
is_sf_swizzled_layout
=
False
layer
.
a1_gscale
,
is_sf_swizzled_layout
=
False
,
)
)
# Call TRT-LLM FP4 block-scale MoE kernel
# Call TRT-LLM FP4 block-scale MoE kernel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment