Commit 85b2e057 (unverified) in sglang
Add int8 quant kernel (#2848)
Authored by Ke Bao on Jan 13, 2025; committed via GitHub on Jan 13, 2025
Parent: a879c2fb
Showing 2 changed files with 147 additions and 0 deletions:
  benchmark/kernels/quantization/bench_int8_quant.py   +94 -0
  python/sglang/srt/layers/quantization/int8_kernel.py  +53 -0
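For context: per_token_quant_int8 performs symmetric per-token (per-row) int8 quantization. Each row is scaled by max(|row|) / 127, rounded, and stored as int8 alongside one float32 scale per row, so the original activation is approximately x_q * scale. A minimal round-trip sketch (illustrative only; it assumes a CUDA device and that this commit's module is importable):

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
x_q, scales = per_token_quant_int8(x)   # x_q: int8 (4, 1024), scales: float32 (4, 1)
x_hat = x_q.to(torch.float32) * scales  # dequantize
print((x_hat - x.float()).abs().max())  # rounding error stays within about half a scale step per row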
benchmark/kernels/quantization/bench_int8_quant.py (new file, mode 100644)
import argparse

import torch
import triton
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8


@torch.compile(backend="inductor")
def torch_int8_quant(x):
    int8_max = torch.iinfo(torch.int8).max
    abs_max = x.abs().max(dim=-1, keepdim=True).values
    scales = abs_max.to(torch.float32) / float(int8_max)
    q_x = (x / scales).round().to(torch.int8)
    return q_x, scales


def _test_accuracy_once(M, K, input_dtype, device):
    x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000
    out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True)
    out1, scales1 = per_token_quant_int8(x)
    out2, scales2 = torch_int8_quant(x)
    torch.testing.assert_close(out, out2, atol=1, rtol=0)
    torch.testing.assert_close(out, out1, atol=1, rtol=0)
    torch.testing.assert_close(scales, scales2)
    torch.testing.assert_close(scales1, scales2)
    print(f"M: {M}, K: {K}, type: {input_dtype} OK")


def test_accuracy():
    Ms = [1, 13, 128, 1024, 2048, 4096]
    Ks = [512, 1024, 2048, 8192]
    input_dtypes = [torch.float16, torch.bfloat16]
    for M in Ms:
        for K in Ks:
            for input_dtype in input_dtypes:
                _test_accuracy_once(M, K, input_dtype, "cuda")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm op", "triton", "torch.compile"],
        line_names=["vllm op", "triton", "torch.compile"],
        styles=[("blue", "-"), ("orange", "-"), ("red", "-")],
        ylabel="ms",
        plot_name="int8 per token quant",
        args={},
    )
)
def benchmark(batch_size, provider):
    M, K = batch_size, 16384
    x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000
    quantiles = [0.5, 0.2, 0.8]
    if provider == "vllm op":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_int8_quant(x, symmetric=True),
            quantiles=quantiles,
        )
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: per_token_quant_int8(x),
            quantiles=quantiles,
        )
    if provider == "torch.compile":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch_int8_quant(x),
            quantiles=quantiles,
        )
    return ms, min_ms, max_ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./bench_int8_quant_res",
        help="Path to save int8 quant benchmark results",
    )
    args = parser.parse_args()

    test_accuracy()

    benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
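Usage note (not part of the diff): the script is meant to be run directly, e.g. python benchmark/kernels/quantization/bench_int8_quant.py --save_path ./bench_int8_quant_res. It first runs the accuracy sweep against the vLLM op, then the latency benchmark; both assume a CUDA GPU with vllm, triton, and sglang installed.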
python/sglang/srt/layers/quantization/int8_kernel.py (new file, mode 100644)
import torch
import triton
import triton.language as tl


@triton.jit
def _per_token_quant_int8(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    BLOCK: tl.constexpr,
):
    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
    row_id = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N

    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
    scale_x = absmax / 127
    x_q = tl.extra.cuda.libdevice.round(x / scale_x).to(tl.int8)

    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
    tl.store(scale_ptr + row_id, scale_x)


def per_token_quant_int8(x):
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)

    assert x.is_contiguous()
    _per_token_quant_int8[(M,)](
        x,
        x_q,
        scales,
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )

    return x_q, scales
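Usage note (illustrative, not part of the commit): since the wrapper derives the row count as x.numel() // x.shape[-1] and allocates scales with shape x.shape[:-1] + (1,), any contiguous tensor with extra leading dimensions is quantized per last-dimension token:

import torch

from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8

x = torch.randn(2, 8, 512, dtype=torch.bfloat16, device="cuda")
x_q, scales = per_token_quant_int8(x)
print(x_q.shape, x_q.dtype)        # torch.Size([2, 8, 512]) torch.int8
print(scales.shape, scales.dtype)  # torch.Size([2, 8, 1]) torch.float32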