Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
774d0c01
Unverified
Commit
774d0c01
authored
Jul 22, 2025
by
Wentao Ye
Committed by
GitHub
Jul 22, 2025
Browse files
[Perf] Cuda Kernel for Per Token Group Quant (#21083)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
2c8db17c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
285 additions
and
4 deletions
+285
-4
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/ops.h
csrc/ops.h
+5
-0
csrc/quantization/fp8/per_token_group_quant.cu
csrc/quantization/fp8/per_token_group_quant.cu
+213
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+9
-0
tests/kernels/quantization/test_per_token_group_quant.py
tests/kernels/quantization/test_per_token_group_quant.py
+44
-0
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+13
-4
No files found.
CMakeLists.txt
View file @
774d0c01
...
...
@@ -245,6 +245,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fp8/per_token_group_quant.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
...
...
csrc/ops.h
View file @
774d0c01
...
...
@@ -297,6 +297,11 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch
::
Tensor
&
scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
);
// Declaration of the per-token-group FP8 quantization entry point.
// `input` is taken by const reference; `output_q` and `output_s` are taken by
// mutable reference. `group_size`, `eps`, `fp8_min`, `fp8_max` and
// `scale_ue8m0` are passed by value.
// NOTE(review): the precise meaning of each argument is defined by the
// implementation in csrc/quantization/fp8/per_token_group_quant.cu - confirm
// there before relying on it.
void
per_token_group_quant_fp8
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
double
fp8_max
,
bool
scale_ue8m0
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
...
...
csrc/quantization/fp8/per_token_group_quant.cu
0 → 100644
View file @
774d0c01
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <torch/all.h>
#include "../vectorization.cuh"
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"
// Butterfly (XOR-shuffle) max-reduction over one 16-lane group.
//
// Two 16-lane groups share each 32-lane warp, so the participation mask has to
// name the half-warp the calling thread belongs to: every thread executing
// __shfl_xor_sync must have its own bit set in `mask`. A fixed 0xffff mask
// leaves lanes 16-31 out of their own mask, which is undefined behavior for
// the second group of every warp.
//
// val: this lane's partial maximum.
// tid: lane index inside the group; not needed for the reduction and kept only
//      so existing call sites remain valid.
// Returns the maximum of `val` across the 16 lanes of the caller's group.
//
// Assumes a one-dimensional thread block, so that threadIdx.x % 32 is the lane
// index within the warp.
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  (void)tid;
  const unsigned mask = (threadIdx.x % 32 >= 16) ? 0xffff0000u : 0x0000ffffu;

  // XOR offsets 8, 4, 2, 1 never leave a 16-lane half-warp.
  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}
// Per-token-group 8-bit quantization kernel.
//
// Each team of 16 consecutive threads handles one group of `group_size`
// elements: it stages the group into shared memory while tracking the absolute
// maximum (seeded with `eps`), derives the scale absmax / max_8bit (rounded up
// to a power of two when SCALE_UE8M0 is set), stores the scale, and finally
// writes the clamped, scaled values to `output_q`.
//
// Template parameters:
//   T               input element type.
//   DST_DTYPE       element type written to `output_q`.
//   IS_COLUMN_MAJOR selects the strided scale addressing that uses
//                   `scale_num_rows` and `scale_stride`.
//   SCALE_UE8M0     rounds the scale up to a power of two.
//   scale_packed_t  storage type of `output_s`; its size must be a multiple of
//                   sizeof(float).
//
// `num_groups` is accepted but not read by the kernel body. The dynamic shared
// memory region is indexed as groups_per_block * group_size elements of T.
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
          bool SCALE_UE8M0 = false, typename scale_packed_t = float>
__global__ void per_token_group_quant_8bit_kernel(
    const T* __restrict__ input, void* __restrict__ output_q,
    scale_packed_t* __restrict__ output_s, const int group_size,
    const int num_groups, const int groups_per_block, const float eps,
    const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
    const int scale_stride = 0) {
  using scale_element_t = float;
  static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);

  // Thread -> (group, lane) mapping.
  const int team_width = 16;
  const int64_t group_in_block = threadIdx.x / team_width;
  const int lane = threadIdx.x % team_width;

  const int64_t first_group_of_block = blockIdx.x * groups_per_block;
  const int64_t group_idx = first_group_of_block + group_in_block;
  const int64_t elem_offset = group_idx * group_size;

  float absmax = eps;

  const T* src_group = input + elem_offset;
  DST_DTYPE* dst_group = static_cast<DST_DTYPE*>(output_q) + elem_offset;

  // Locate the slot that receives this group's scale.
  scale_element_t* scale_slot;
  if constexpr (IS_COLUMN_MAJOR) {
    const int elems_per_pack =
        static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
    const int rows_in_elements = scale_num_rows * elems_per_pack;
    const int row = group_idx / rows_in_elements;
    const int col_raw = group_idx % rows_in_elements;
    const int col = col_raw / elems_per_pack;
    const int pack = col_raw % elems_per_pack;
    scale_slot = reinterpret_cast<scale_element_t*>(output_s) +
                 (col * scale_stride * elems_per_pack + row * elems_per_pack +
                  pack);
  } else {
    scale_slot = output_s + group_idx;
  }

  // Shared memory caches each group's data so global memory is read only once.
  extern __shared__ __align__(16) char smem_raw[];
  T* smem = reinterpret_cast<T*>(smem_raw);
  T* staged_group = smem + group_in_block * group_size;

  constexpr int vec_size = 16 / sizeof(T);
  using vec_t = vllm::vec_n_t<T, vec_size>;

  // Pass 1: global -> shared copy, folding |x| into the running maximum.
  auto stage_and_track_absmax = [&] __device__(T & dst, const T& src) {
    absmax = fmaxf(absmax, fabsf(static_cast<float>(src)));
    dst = src;
  };
  vllm::vectorize_with_alignment<vec_size>(
      src_group,                // in
      staged_group,             // out (shared)
      group_size,               // elements per group
      lane,                     // thread id
      team_width,               // stride in group
      stage_and_track_absmax);  // scalar handler

  absmax = GroupReduceMax(absmax, lane);

  float y_s = absmax / max_8bit;
  if constexpr (SCALE_UE8M0) {
    // Round the scale up to the next power of two.
    const float floored = fmaxf(fabsf(y_s), 1e-10f);
    const float exponent = ceilf(log2f(floored));
    y_s = exp2f(exponent);
  }

  // One lane per group publishes the scale.
  if (lane == 0) {
    const scale_element_t scale_value = y_s;
    *scale_slot = scale_value;
  }

  // Block-wide barrier between staging into shared memory and reading it back.
  __syncthreads();

  // Pass 2: shared -> global, scaling and clamping into the 8-bit range.
  auto scale_and_clamp = [&] __device__(DST_DTYPE & dst, const T& src) {
    const float scaled = static_cast<float>(src) / y_s;
    const float clamped = fminf(fmaxf(scaled, min_8bit), max_8bit);
    dst = DST_DTYPE(clamped);
  };
  vllm::vectorize_with_alignment<vec_size>(
      staged_group,      // in (shared)
      dst_group,         // out (global quant tensor)
      group_size,        // elements
      lane,              // tid
      team_width,        // stride
      scale_and_clamp);  // scalar handler
}
// Host-side launcher for per_token_group_quant_8bit_kernel.
//
// input:       contiguous tensor whose element count is a multiple of
//              `group_size`; dispatched over the floating types handled by
//              VLLM_DISPATCH_FLOATING_TYPES.
// output_q:    contiguous Float8_e4m3fn tensor that receives the quantized
//              values.
// output_s:    2-D float32 tensor that receives one scale per group; it is
//              treated as column-major when stride(0) < stride(1).
// group_size:  number of consecutive elements sharing one scale; must be > 0.
// eps:         lower bound used to seed the per-group absolute maximum.
// min_8bit / max_8bit: clamp range of the quantized values.
// scale_ue8m0: round every scale up to a power of two.
//
// Raises (via TORCH_CHECK) when any precondition above is violated. An empty
// input is a no-op.
void per_token_group_quant_8bit(const torch::Tensor& input,
                                torch::Tensor& output_q,
                                torch::Tensor& output_s, int64_t group_size,
                                double eps, double min_8bit, double max_8bit,
                                bool scale_ue8m0 = false) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(output_q.is_contiguous());
  // Validate the divisor before it is used below.
  TORCH_CHECK(group_size > 0, "group_size must be positive, got ", group_size);
  TORCH_CHECK(input.numel() % group_size == 0);
  TORCH_CHECK(output_s.dim() == 2);
  // The kernel writes scales through a float*, so any other dtype would
  // corrupt memory.
  TORCH_CHECK(output_s.scalar_type() == at::ScalarType::Float,
              "output_s must be a float32 tensor");

  // Only FP8 e4m3 output is implemented; fail loudly instead of returning with
  // output_q and output_s left unwritten.
  auto dst_type = output_q.scalar_type();
  TORCH_CHECK(dst_type == at::ScalarType::Float8_e4m3fn,
              "per_token_group_quant_8bit: unsupported output dtype ",
              dst_type);

  const int num_groups = input.numel() / group_size;
  if (num_groups == 0) {
    // A grid with zero blocks is an invalid launch configuration.
    return;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  constexpr int THREADS_PER_GROUP = 16;

  // Largest of 16, 8, 4, 2, 1 that divides num_groups, so that every block
  // processes the same number of whole groups.
  int groups_per_block = 1;
  if (num_groups % 16 == 0) {
    groups_per_block = 16;
  } else if (num_groups % 8 == 0) {
    groups_per_block = 8;
  } else if (num_groups % 4 == 0) {
    groups_per_block = 4;
  } else if (num_groups % 2 == 0) {
    groups_per_block = 2;
  }

  const int num_blocks = num_groups / groups_per_block;
  const int num_threads = groups_per_block * THREADS_PER_GROUP;

  const bool is_column_major = output_s.stride(0) < output_s.stride(1);
  const int scale_num_rows = output_s.size(1);
  const int scale_stride = output_s.stride(1);

  const float eps_f = static_cast<float>(eps);
  const float min_8bit_f = static_cast<float>(min_8bit);
  const float max_8bit_f = static_cast<float>(max_8bit);

#define LAUNCH_KERNEL(T, DST_DTYPE)                                        \
  do {                                                                     \
    dim3 grid(num_blocks);                                                 \
    dim3 block(num_threads);                                               \
    size_t smem_bytes =                                                    \
        static_cast<size_t>(groups_per_block) * group_size * sizeof(T);    \
    if (is_column_major) {                                                 \
      if (scale_ue8m0) {                                                   \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true>        \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, eps_f, min_8bit_f,           \
                max_8bit_f, scale_num_rows, scale_stride);                 \
      } else {                                                             \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false>       \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, eps_f, min_8bit_f,           \
                max_8bit_f, scale_num_rows, scale_stride);                 \
      }                                                                    \
    } else {                                                               \
      if (scale_ue8m0) {                                                   \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true>       \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, eps_f, min_8bit_f,           \
                max_8bit_f);                                               \
      } else {                                                             \
        per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false>      \
            <<<grid, block, smem_bytes, stream>>>(                         \
                static_cast<T*>(input.data_ptr()), output_q.data_ptr(),    \
                static_cast<float*>(output_s.data_ptr()), group_size,      \
                num_groups, groups_per_block, eps_f, min_8bit_f,           \
                max_8bit_f);                                               \
      }                                                                    \
    }                                                                      \
  } while (0)

  // dst_type was verified above, so the FP8 e4m3 instantiation is the only
  // one required here.
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "per_token_group_quant_8bit",
      ([&] { LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn); }));

#undef LAUNCH_KERNEL
}
// FP8 entry point for per-token-group quantization.
//
// Forwards every argument unchanged to per_token_group_quant_8bit, passing
// `fp8_min` / `fp8_max` as that function's `min_8bit` / `max_8bit` clamp range.
void per_token_group_quant_fp8(const torch::Tensor& input,
                               torch::Tensor& output_q,
                               torch::Tensor& output_s, int64_t group_size,
                               double eps, double fp8_min, double fp8_max,
                               bool scale_ue8m0) {
  return per_token_group_quant_8bit(input, output_q, output_s, group_size,
                                    /*eps=*/eps, /*min_8bit=*/fp8_min,
                                    /*max_8bit=*/fp8_max, scale_ue8m0);
}
csrc/torch_bindings.cpp
View file @
774d0c01
...
...
@@ -601,6 +601,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"dynamic_scaled_int8_quant"
,
torch
::
kCUDA
,
&
dynamic_scaled_int8_quant
);
// Compute per-token-group FP8 quantized tensor and scaling factor.
ops
.
def
(
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
"output_s, "
"int group_size, float eps, float fp8_min, float fp8_max, bool "
"scale_ue8m0) -> ()"
);
ops
.
impl
(
"per_token_group_fp8_quant"
,
torch
::
kCUDA
,
&
per_token_group_quant_fp8
);
// Mamba selective scan kernel
ops
.
def
(
"selective_scan_fwd(Tensor! u, Tensor! delta,"
...
...
tests/kernels/quantization/test_per_token_group_quant.py
0 → 100644
View file @
774d0c01
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.model_executor.layers.quantization.utils
import
fp8_utils
@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)])
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
                                   scale_ue8m0: bool, group_size: int):
    """Check the default quantization path against the fallback path.

    ``fp8_utils.per_token_group_quant_fp8`` is called twice on the same
    bfloat16 input: once as-is, and once with
    ``current_platform.is_cuda`` patched to return ``False``. The quantized
    values and the scales of both runs must agree within loose tolerances.
    """
    torch.manual_seed(42)
    rows, cols = shape
    x = torch.randn((rows, cols), device="cuda", dtype=torch.bfloat16) * 8

    quant_kwargs = dict(
        column_major_scales=column_major,
        use_ue8m0=scale_ue8m0,
    )

    # cuda path
    out_q, scale = fp8_utils.per_token_group_quant_fp8(x, group_size,
                                                       **quant_kwargs)

    # triton ref
    with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
        ref_q, ref_s = fp8_utils.per_token_group_quant_fp8(
            x, group_size, **quant_kwargs)

    assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15)
    assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01)
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
774d0c01
...
...
@@ -366,6 +366,7 @@ def per_token_group_quant_fp8(
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
column_major_scales
:
bool
=
False
,
out_q
:
Optional
[
torch
.
Tensor
]
=
None
,
use_ue8m0
:
bool
=
is_blackwell_deep_gemm_used
(),
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
...
...
@@ -397,8 +398,7 @@ def per_token_group_quant_fp8(
if
x_q
is
None
:
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
# Allocate the scale tensor in either row- or column-major format.
if
column_major_scales
:
shape
=
(
x
.
shape
[
-
1
]
//
group_size
,
)
+
x
.
shape
[:
-
1
]
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
...
...
@@ -407,6 +407,15 @@ def per_token_group_quant_fp8(
shape
=
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
)
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
)
# prefer CUDA kernel if available
if
current_platform
.
is_cuda
()
and
x
.
is_contiguous
():
torch
.
ops
.
_C
.
per_token_group_fp8_quant
(
x
,
x_q
,
x_s
,
group_size
,
eps
,
fp8_min
,
fp8_max
,
use_ue8m0
)
return
x_q
,
x_s
# TRITON FALLBACK
M
=
x
.
numel
()
//
group_size
N
=
group_size
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
...
...
@@ -423,7 +432,7 @@ def per_token_group_quant_fp8(
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
use_ue8m0
=
is_blackwell_deep_gemm_used
()
,
use_ue8m0
=
use_ue8m0
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
...
...
@@ -439,7 +448,7 @@ def per_token_group_quant_fp8(
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
use_ue8m0
=
is_blackwell_deep_gemm_used
()
,
use_ue8m0
=
use_ue8m0
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment