Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
75d29cf4
Unverified
Commit
75d29cf4
authored
Jul 25, 2025
by
Wentao Ye
Committed by
GitHub
Jul 25, 2025
Browse files
[Perf] Cuda Kernel for Int8 Per Token Group Quant (#21476)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
41d3082c
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
47 additions
and
3 deletions
+47
-3
csrc/ops.h
csrc/ops.h
+5
-0
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+10
-0
csrc/quantization/fp8/per_token_group_quant.cu
csrc/quantization/fp8/per_token_group_quant.cu
+5
-1
csrc/quantization/per_token_group_quant_8bit.h
csrc/quantization/per_token_group_quant_8bit.h
+10
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+8
-0
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+9
-2
No files found.
csrc/ops.h
View file @
75d29cf4
...
@@ -292,6 +292,11 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
...
@@ -292,6 +292,11 @@ void per_token_group_quant_fp8(const torch::Tensor& input,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
int64_t
group_size
,
double
eps
,
double
fp8_min
,
double
fp8_max
,
bool
scale_ue8m0
);
double
fp8_max
,
bool
scale_ue8m0
);
void
per_token_group_quant_int8
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
int8_min
,
double
int8_max
);
#endif
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
75d29cf4
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <torch/all.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include <cmath>
#include "../../dispatch_utils.h"
#include "../../dispatch_utils.h"
...
@@ -336,3 +338,11 @@ void dynamic_scaled_int8_quant(
...
@@ -336,3 +338,11 @@ void dynamic_scaled_int8_quant(
}
}
});
});
}
}
void
per_token_group_quant_int8
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
int8_min
,
double
int8_max
)
{
per_token_group_quant_8bit
(
input
,
output_q
,
output_s
,
group_size
,
eps
,
int8_min
,
int8_max
);
}
\ No newline at end of file
csrc/quantization/fp8/per_token_group_quant.cu
View file @
75d29cf4
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fn.h>
#include "../per_token_group_quant_8bit.h"
#include <cmath>
#include <cmath>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
...
@@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
...
@@ -120,7 +122,7 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
min_8bit
,
double
max_8bit
,
double
eps
,
double
min_8bit
,
double
max_8bit
,
bool
scale_ue8m0
=
false
)
{
bool
scale_ue8m0
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
output_q
.
is_contiguous
());
TORCH_CHECK
(
output_q
.
is_contiguous
());
...
@@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
...
@@ -198,6 +200,8 @@ void per_token_group_quant_8bit(const torch::Tensor& input,
input
.
scalar_type
(),
"per_token_group_quant_8bit"
,
([
&
]
{
input
.
scalar_type
(),
"per_token_group_quant_8bit"
,
([
&
]
{
if
(
dst_type
==
at
::
ScalarType
::
Float8_e4m3fn
)
{
if
(
dst_type
==
at
::
ScalarType
::
Float8_e4m3fn
)
{
LAUNCH_KERNEL
(
scalar_t
,
c10
::
Float8_e4m3fn
);
LAUNCH_KERNEL
(
scalar_t
,
c10
::
Float8_e4m3fn
);
}
else
if
(
dst_type
==
at
::
ScalarType
::
Char
)
{
LAUNCH_KERNEL
(
scalar_t
,
int8_t
);
}
}
}));
}));
...
...
csrc/quantization/per_token_group_quant_8bit.h
0 → 100644
View file @
75d29cf4
#pragma once
#include <torch/all.h>
// TODO(wentao): refactor the folder to 8bit, then includes fp8 and int8 folders
// 8-bit per-token-group quantization helper used by both FP8 and INT8
void
per_token_group_quant_8bit
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
&
output_q
,
torch
::
Tensor
&
output_s
,
int64_t
group_size
,
double
eps
,
double
min_8bit
,
double
max_8bit
,
bool
scale_ue8m0
=
false
);
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
75d29cf4
...
@@ -624,6 +624,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -624,6 +624,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"per_token_group_fp8_quant"
,
torch
::
kCUDA
,
ops
.
impl
(
"per_token_group_fp8_quant"
,
torch
::
kCUDA
,
&
per_token_group_quant_fp8
);
&
per_token_group_quant_fp8
);
// Compute per-token-group INT8 quantized tensor and scaling factor.
ops
.
def
(
"per_token_group_quant_int8(Tensor input, Tensor! output_q, Tensor! "
"output_s, int group_size, float eps, float int8_min, float int8_max) -> "
"()"
);
ops
.
impl
(
"per_token_group_quant_int8"
,
torch
::
kCUDA
,
&
per_token_group_quant_int8
);
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
ops
.
def
(
ops
.
def
(
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
...
...
vllm/model_executor/layers/quantization/utils/int8_utils.py
View file @
75d29cf4
...
@@ -238,13 +238,20 @@ def per_token_group_quant_int8(
...
@@ -238,13 +238,20 @@ def per_token_group_quant_int8(
int8_min
=
iinfo
.
min
int8_min
=
iinfo
.
min
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
x_s
=
torch
.
empty
(
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
),
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
),
device
=
x
.
device
,
device
=
x
.
device
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
)
)
# prefer CUDA kernel if available
if
current_platform
.
is_cuda
():
torch
.
ops
.
_C
.
per_token_group_quant_int8
(
x
,
x_q
,
x_s
,
group_size
,
eps
,
float
(
int8_min
),
float
(
int8_max
))
return
x_q
,
x_s
M
=
x
.
numel
()
//
group_size
N
=
group_size
BLOCK
=
triton
.
next_power_of_2
(
N
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
# heuristics for number of warps
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment