Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b6efafd9
Unverified
Commit
b6efafd9
authored
Jun 12, 2025
by
Wentao Ye
Committed by
GitHub
Jun 12, 2025
Browse files
[Perf] Vectorize static / dynamic INT8 quant kernels (#19233)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
1129e2b1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
415 additions
and
101 deletions
+415
-101
benchmarks/kernels/bench_int8_gemm.py
benchmarks/kernels/bench_int8_gemm.py
+200
-0
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+139
-101
csrc/quantization/vectorization_utils.cuh
csrc/quantization/vectorization_utils.cuh
+75
-0
tests/kernels/quantization/test_int8_quant.py
tests/kernels/quantization/test_int8_quant.py
+1
-0
No files found.
benchmarks/kernels/bench_int8_gemm.py
0 → 100644
View file @
b6efafd9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm._custom_ops
import
cutlass_scaled_mm
as
vllm_scaled_mm
from
vllm._custom_ops
import
scaled_int8_quant
as
vllm_scaled_int8_quant
from
vllm.triton_utils
import
triton
# Providers plotted by default. The commented-out names are also understood by
# ``benchmark`` and can be enabled by uncommenting them.
_PROVIDERS = [
    "torch-bf16",
    # "int8-tensor-w-token-a",
    "int8-tensor-w-tensor-a",
    "int8-channel-w-token-a",
    # "int8-channel-w-tensor-a",
    # "int8-tensor-w-token-a-noquant",
    "int8-tensor-w-tensor-a-noquant",
    "int8-channel-w-token-a-noquant",
    # "int8-channel-w-tensor-a-noquant",
]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=list(_PROVIDERS),
        line_names=list(_PROVIDERS),
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs INT8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one (batch_size x K) @ (K x N) GEMM for ``provider``.

    Args:
        batch_size: Number of rows ``M`` of the activation matrix.
        provider: Name selecting the implementation. ``"torch-bf16"`` runs
            ``torch.nn.functional.linear`` on the bf16 tensors. Names
            containing ``"int8"`` run ``vllm_scaled_mm`` on int8 operands and
            must also contain one of ``tensor-w-token-a``,
            ``tensor-w-tensor-a``, ``channel-w-token-a`` or
            ``channel-w-tensor-a``. A ``noquant`` suffix quantizes the
            activations once, outside the timed region.
        N: Number of weight rows (output features).
        K: Shared inner dimension.

    Returns:
        A 3-tuple of TFLOP/s values computed from the timings requested with
        ``quantiles=[0.5, 0.2, 0.8]``: the value for the 0.5 quantile first,
        then the value for the 0.8-quantile time (the slower one, hence the
        lower throughput), then the value for the 0.2-quantile time.

    Raises:
        ValueError: If ``provider`` does not name a known configuration.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if "torch-bf16" in provider:

        def run():
            return torch.nn.functional.linear(a, b)

    elif "int8" in provider:
        # Decode the provider name into (fixed weight scale?, fixed
        # activation scale?). A fixed scale is a single 1.0; otherwise the
        # quantization op is called without a scale and returns its own.
        if "tensor-w-token-a" in provider:
            static_w, static_a = True, False
        elif "tensor-w-tensor-a" in provider:
            static_w, static_a = True, True
        elif "channel-w-token-a" in provider:
            static_w, static_a = False, False
        elif "channel-w-tensor-a" in provider:
            static_w, static_a = False, True
        else:
            raise ValueError(f"Unknown int8 provider: {provider!r}")

        def unit_scale():
            return torch.tensor([1.0], device=device, dtype=torch.float32)

        def quantize(x, scale):
            # Call the op exactly as the per-branch code did: with a scale
            # argument only when a fixed scale is used.
            if scale is None:
                return vllm_scaled_int8_quant(x)
            return vllm_scaled_int8_quant(x, scale)

        scale_a = unit_scale() if static_a else None
        scale_b = unit_scale() if static_w else None

        # Weights are always quantized ahead of time.
        b_int8, scale_b_int8, _ = quantize(b, scale_b)
        assert scale_b_int8.numel() == (1 if static_w else N)

        if "noquant" in provider:
            # For "no quant", we don't measure the time for activations.
            a_int8, scale_a_int8, _ = quantize(a, scale_a)
            b_int8 = b_int8.t()

            def run():
                return vllm_scaled_mm(
                    a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
                )

        else:
            # Quantize the activations during the (timed) GEMM call.
            b_int8 = b_int8.t()

            def run():
                a_int8, scale_a_int8, _ = quantize(a, scale_a)
                return vllm_scaled_mm(
                    a_int8, b_int8, scale_a_int8, scale_b_int8, dtype
                )

    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        run, quantiles=quantiles
    )

    def tflops(elapsed_ms):
        # Two flops per multiply-add; elapsed_ms is in milliseconds.
        return (2 * M * N * K) * 1e-12 / (elapsed_ms * 1e-3)

    return tflops(ms), tflops(max_ms), tflops(min_ms)
def prepare_shapes(args):
    """Expand the requested models and TP sizes into benchmark shapes.

    For every ``(model, tp_size)`` pair from ``args.models`` x
    ``args.tp_sizes``, each ``(KN, tp_split_dim)`` entry of
    ``WEIGHT_SHAPES[model]`` yields one new list: a copy of ``KN`` with the
    element at ``tp_split_dim`` floor-divided by ``tp_size`` and the model
    name appended. ``WEIGHT_SHAPES`` itself is not modified.

    Args:
        args: Object with ``models`` and ``tp_sizes`` iterables (the parsed
            command-line namespace).

    Returns:
        List of ``[*KN_after_split, model_name]`` lists, ordered by model,
        then TP size, then the order of the model's shape entries.
    """
    KN_model_names = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        assert model in WEIGHT_SHAPES, f"unknown model: {model!r}"
        for KN, tp_split_dim in WEIGHT_SHAPES[model]:
            # Copy so the shared shape table is never mutated.
            shape = list(KN)
            shape[tp_split_dim] //= tp_size
            shape.append(model)
            KN_model_names.append(shape)
    return KN_model_names
if __name__ == "__main__":
    # Command-line entry point: benchmark every weight shape of the selected
    # models at each requested tensor-parallel size.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES),
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_int8_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
b6efafd9
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <cmath>
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
...
...
@@ -103,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) {
namespace
vllm
{
template
<
typename
scalar_t
,
typename
scale_t
ype
>
template
<
typename
scalar_t
,
typename
scale_t
>
__global__
void
static_scaled_int8_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
const
scale_t
*
scale_ptr
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
float
scale
=
*
scale_ptr
;
// Must be performed using 64-bit math to avoid integer overflow.
o
ut
+
=
token_idx
*
hidden_size
;
input
+
=
token_idx
*
hidden_size
;
const
scalar_t
*
row_in
=
inp
ut
+
token_idx
*
hidden_size
;
in
t8_t
*
row_out
=
out
put
+
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
/
scale
);
}
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
dst
=
float_to_int8_rn
(
static_cast
<
float
>
(
src
)
/
scale
);
});
}
template
<
typename
scalar_t
,
typename
scale_t
ype
,
typename
azp_t
ype
>
template
<
typename
scalar_t
,
typename
scale_t
,
typename
azp_t
>
__global__
void
static_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
const
*
scale_ptr
,
azp_type
const
*
azp_ptr
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
scale_type
const
scale
=
*
scale_ptr
;
azp_type
const
azp
=
*
azp_ptr
;
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
const
scale_t
*
scale_ptr
,
const
azp_t
*
azp_ptr
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
const
float
scale
=
*
scale_ptr
;
const
azp_t
azp
=
*
azp_ptr
;
const
float
inv_s
=
1.0
f
/
scale
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale
)
+
azp
);
out
[
i
]
=
quant_val
;
}
const
scalar_t
*
row_in
=
input
+
token_idx
*
hidden_size
;
int8_t
*
row_out
=
output
+
token_idx
*
hidden_size
;
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
const
auto
v
=
static_cast
<
float
>
(
src
)
*
inv_s
;
dst
=
int32_to_int8
(
float_to_int32_rn
(
v
)
+
azp
);
});
}
template
<
typename
scalar_t
,
typename
scale_t
ype
>
template
<
typename
scalar_t
,
typename
scale_t
>
__global__
void
dynamic_scaled_int8_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
const
int
hidden_size
)
{
int
const
tid
=
threadIdx
.
x
;
int64_t
const
token_idx
=
blockIdx
.
x
;
float
absmax_val
=
0.0
f
;
float
const
zero
=
0.0
f
;
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
scale_t
*
scale_out
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
val
=
static_cast
<
float
>
(
input
[
i
]);
val
=
val
>
zero
?
val
:
-
val
;
absmax_val
=
val
>
absmax_val
?
val
:
absmax_val
;
const
scalar_t
*
row_in
=
input
+
token_idx
*
hidden_size
;
int8_t
*
row_out
=
output
+
token_idx
*
hidden_size
;
// calculate for absmax
float
thread_max
=
0.
f
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
stride
)
{
const
auto
v
=
fabsf
(
static_cast
<
float
>
(
row_in
[
i
]));
thread_max
=
fmaxf
(
thread_max
,
v
);
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
float
const
block_absmax_val_maybe
=
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
block_absmax_val
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
float
block_max
=
BlockReduce
(
tmp
).
Reduce
(
thread_max
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
absmax
;
if
(
tid
==
0
)
{
block_
absmax
_val
=
block_
absmax_val_maybe
;
scale
[
token_idx
]
=
block_
absmax
_val
/
127.
0
f
;
absmax
=
block_
max
;
scale
_out
[
blockIdx
.
x
]
=
absmax
/
127.
f
;
}
__syncthreads
();
float
const
tmp_scale
=
127.0
f
/
block_absmax_val
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
i
]
=
float_to_int8_rn
(
static_cast
<
float
>
(
input
[
i
])
*
tmp_scale
);
float
inv_s
=
(
absmax
==
0.
f
)
?
0.
f
:
127.
f
/
absmax
;
// 2. quantize
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
dst
=
float_to_int8_rn
(
static_cast
<
float
>
(
src
)
*
inv_s
);
});
}
// MinMax structure to hold min and max values in one go
struct
MinMax
{
float
min
,
max
;
__host__
__device__
MinMax
()
:
min
(
std
::
numeric_limits
<
float
>::
max
()),
max
(
std
::
numeric_limits
<
float
>::
lowest
())
{}
__host__
__device__
explicit
MinMax
(
float
v
)
:
min
(
v
),
max
(
v
)
{}
// add a value to the MinMax
__host__
__device__
MinMax
&
operator
+=
(
float
v
)
{
min
=
fminf
(
min
,
v
);
max
=
fmaxf
(
max
,
v
);
return
*
this
;
}
// merge two MinMax objects
__host__
__device__
MinMax
&
operator
&=
(
const
MinMax
&
other
)
{
min
=
fminf
(
min
,
other
.
min
);
max
=
fmaxf
(
max
,
other
.
max
);
return
*
this
;
}
};
__host__
__device__
inline
MinMax
operator
+
(
MinMax
a
,
float
v
)
{
return
a
+=
v
;
}
__host__
__device__
inline
MinMax
operator
&
(
MinMax
a
,
const
MinMax
&
b
)
{
return
a
&=
b
;
}
template
<
typename
scalar_t
,
typename
scale_t
ype
,
typename
azp_t
ype
>
template
<
typename
scalar_t
,
typename
scale_t
,
typename
azp_t
>
__global__
void
dynamic_scaled_int8_azp_quant_kernel
(
scalar_t
const
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
*
scale
,
azp_type
*
azp
,
const
int
hidden_size
)
{
int64_t
const
token_idx
=
blockIdx
.
x
;
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
output
,
scale_t
*
scale_out
,
azp_t
*
azp_out
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
stride
=
blockDim
.
x
;
const
int64_t
token_idx
=
blockIdx
.
x
;
// Must be performed using 64-bit math to avoid integer overflow.
out
+=
token_idx
*
hidden_size
;
input
+=
token_idx
*
hidden_size
;
// Scan for the min and max value for this token
float
max_val
=
std
::
numeric_limits
<
float
>::
min
();
float
min_val
=
std
::
numeric_limits
<
float
>::
max
();
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
val
=
static_cast
<
float
>
(
input
[
i
]);
max_val
=
std
::
max
(
max_val
,
val
);
min_val
=
std
::
min
(
min_val
,
val
);
}
const
scalar_t
*
row_in
=
input
+
token_idx
*
hidden_size
;
int8_t
*
row_out
=
output
+
token_idx
*
hidden_size
;
// Reduce the max and min values across the block
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
max_val
=
BlockReduce
(
reduceStorage
).
Reduce
(
max_val
,
cub
::
Max
{},
blockDim
.
x
);
__syncthreads
();
// Make sure min doesn't mess with max shared memory
min_val
=
BlockReduce
(
reduceStorage
).
Reduce
(
min_val
,
cub
::
Min
{},
blockDim
.
x
);
__shared__
scale_type
scale_sh
;
__shared__
azp_type
azp_sh
;
// Compute the scale and zero point and store them, only on the first thread
if
(
threadIdx
.
x
==
0
)
{
float
const
scale_val
=
(
max_val
-
min_val
)
/
255.0
f
;
// Use rounding to even (same as torch.round)
auto
const
azp_float
=
std
::
nearbyint
(
-
128.0
f
-
min_val
/
scale_val
);
auto
const
azp_val
=
static_cast
<
azp_type
>
(
azp_float
);
// Store the scale and azp into shared and global
scale
[
token_idx
]
=
scale_sh
=
scale_val
;
azp
[
token_idx
]
=
azp_sh
=
azp_val
;
// 1. calculate min & max
MinMax
thread_mm
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
stride
)
{
thread_mm
+=
static_cast
<
float
>
(
row_in
[
i
]);
}
// Wait for the scale and azp to be computed
__s
yncthreads
()
;
using
BlockReduce
=
cub
::
BlockReduce
<
MinMax
,
256
>
;
__s
hared__
typename
BlockReduce
::
TempStorage
tmp
;
float
const
scale_val
=
scale_sh
;
azp_type
const
azp_val
=
azp_sh
;
MinMax
mm
=
BlockReduce
(
tmp
).
Reduce
(
thread_mm
,
[]
__device__
(
MinMax
a
,
const
MinMax
&
b
)
{
a
&=
b
;
return
a
;
},
blockDim
.
x
);
// Quantize the values
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
auto
const
val
=
static_cast
<
float
>
(
input
[
i
]);
auto
const
quant_val
=
int32_to_int8
(
float_to_int32_rn
(
val
/
scale_val
)
+
azp_val
);
out
[
i
]
=
quant_val
;
__shared__
float
scale_sh
;
__shared__
azp_t
azp_sh
;
if
(
tid
==
0
)
{
float
s
=
(
mm
.
max
-
mm
.
min
)
/
255.
f
;
float
zp
=
nearbyintf
(
-
128.
f
-
mm
.
min
/
s
);
// round-to-even
scale_sh
=
s
;
azp_sh
=
azp_t
(
zp
);
scale_out
[
blockIdx
.
x
]
=
s
;
azp_out
[
blockIdx
.
x
]
=
azp_sh
;
}
__syncthreads
();
const
float
inv_s
=
1.
f
/
scale_sh
;
const
azp_t
azp
=
azp_sh
;
// 2. quantize
vectorize_with_alignment
<
16
>
(
row_in
,
row_out
,
hidden_size
,
tid
,
stride
,
[
=
]
__device__
(
int8_t
&
dst
,
const
scalar_t
&
src
)
{
const
auto
v
=
static_cast
<
float
>
(
src
)
*
inv_s
;
dst
=
int32_to_int8
(
float_to_int32_rn
(
v
)
+
azp
);
});
}
}
// namespace vllm
...
...
@@ -247,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
const
block
(
std
::
min
(
hidden_size
,
256
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
...
...
@@ -278,7 +316,7 @@ void dynamic_scaled_int8_quant(
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
const
block
(
std
::
min
(
hidden_size
,
256
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"dynamic_scaled_int8_quant_kernel"
,
[
&
]
{
...
...
csrc/quantization/vectorization_utils.cuh
0 → 100644
View file @
b6efafd9
#pragma once
#include "vectorization.cuh"
namespace
vllm
{
// Lifts a per-element functor to a whole vector: calls `scalar_op` once for
// each of the VEC_SIZE entries of `src.val`, passing the entry of `dst.val`
// at the same index as the destination.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
struct DefaultVecOp {
  ScaOp scalar_op;

  __device__ __forceinline__ void operator()(
      vec_n_t<OutT, VEC_SIZE>& dst,
      const vec_n_t<InT, VEC_SIZE>& src) const {
#pragma unroll
    for (int lane = 0; lane < VEC_SIZE; ++lane) {
      scalar_op(dst.val[lane], src.val[lane]);
    }
  }
};
// Applies an element-wise conversion from `in` to `out` over `len` elements,
// using VEC_SIZE-wide vector loads/stores for the aligned middle part.
//
//   in, out   : input / output arrays of `len` elements each
//   tid       : index of the calling thread among the cooperating threads
//   stride    : number of cooperating threads (loop step)
//   vec_op    : callable (vec_n_t<OutT, VEC_SIZE>&, const vec_n_t<InT, VEC_SIZE>&)
//   scalar_op : callable (OutT&, const InT&)
//
// The work is split into three parts:
//   1. a scalar prefix up to the first input address that is a multiple of
//      VEC_SIZE * sizeof(InT) bytes,
//   2. the vectorized main part,
//   3. a scalar tail of fewer than VEC_SIZE elements.
// If the output address that corresponds to the start of the main part is not
// suitably aligned for vector stores, the whole range is processed with
// `scalar_op` instead.
template <int VEC_SIZE, typename InT, typename OutT, typename VecOp,
          typename ScaOp>
__device__ inline void vectorize_with_alignment(
    const InT* in, OutT* out, int len, int tid, int stride,
    VecOp&& vec_op,       // vec_n_t<InT, VEC_SIZE> -> vec_n_t<OutT, VEC_SIZE>
    ScaOp&& scalar_op) {  // InT -> OutT
  static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
                "VEC_SIZE must be a positive power-of-two");

  using vin_t = vec_n_t<InT, VEC_SIZE>;
  using vout_t = vec_n_t<OutT, VEC_SIZE>;

  // Bytes covered by one input vector (e.g. 16 floats -> 64 B).
  constexpr int WIDTH = VEC_SIZE * sizeof(InT);
  uintptr_t addr = reinterpret_cast<uintptr_t>(in);

  int misalignment_offset = addr & (WIDTH - 1);        // addr % WIDTH
  int alignment_bytes = WIDTH - misalignment_offset;   // WIDTH - (addr % WIDTH)
  // Maps the already-aligned case (alignment_bytes == WIDTH) to 0.
  int prefix_elems = alignment_bytes & (WIDTH - 1);
  prefix_elems /= sizeof(InT);
  prefix_elems = min(prefix_elems, len);  // 0 <= prefix_elems < VEC_SIZE

  // Only the input address was used to choose the prefix. The vector stores
  // below are only valid if the matching output address is aligned for
  // vout_t as well; otherwise fall back to the scalar path for everything.
  const uintptr_t out_addr = reinterpret_cast<uintptr_t>(out + prefix_elems);
  if ((out_addr & (alignof(vout_t) - 1)) != 0) {
    for (int i = tid; i < len; i += stride) {
      scalar_op(out[i], in[i]);
    }
    return;
  }

  // 1. scalar prefix: elements before it is safe to vectorize
  for (int i = tid; i < prefix_elems; i += stride) {
    scalar_op(out[i], in[i]);
  }

  in += prefix_elems;
  out += prefix_elems;
  len -= prefix_elems;

  int num_vec = len / VEC_SIZE;
  auto* v_in = reinterpret_cast<const vin_t*>(in);
  auto* v_out = reinterpret_cast<vout_t*>(out);

  // 2. vectorize the main part
  for (int i = tid; i < num_vec; i += stride) {
    vout_t tmp;
    vec_op(tmp, v_in[i]);
    v_out[i] = tmp;
  }

  // 3. handle the tail
  int tail_start = num_vec * VEC_SIZE;
  for (int i = tid + tail_start; i < len; i += stride) {
    scalar_op(out[i], in[i]);
  }
}
// Convenience overload taking only a per-element functor: the vector
// operation is a DefaultVecOp built from a copy of `scalar_op`, and both are
// handed to the overload that takes a vector op and a scalar op.
template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp>
__device__ __forceinline__ void vectorize_with_alignment(const InT* in,
                                                         OutT* out, int len,
                                                         int tid, int stride,
                                                         ScaOp&& scalar_op) {
  using LaneWiseOp = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>;

  vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride,
                                     LaneWiseOp{scalar_op},
                                     std::forward<ScaOp>(scalar_op));
}
}
// namespace vllm
tests/kernels/quantization/test_int8_quant.py
View file @
b6efafd9
...
...
@@ -11,6 +11,7 @@ from vllm.platforms import current_platform
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HIDDEN_SIZES
=
[
16
,
67
,
768
,
5137
,
8193
]
# Arbitrary values for testing
HIDDEN_SIZES
+=
list
(
range
(
1024
,
1033
))
# vectorized conversion edge cases
NUM_TOKENS
=
[
1
,
7
,
83
,
4096
]
# Arbitrary values for testing
SEEDS
=
[
0
]
SCALE
=
[
0.1
,
2.1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment