Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
85b2e057
"docs/vscode:/vscode.git/clone" did not exist on "f082491b29ee115d822181eaf1ea5618570242b7"
Unverified
Commit
85b2e057
authored
Jan 13, 2025
by
Ke Bao
Committed by
GitHub
Jan 13, 2025
Browse files
Add int8 quant kernel (#2848)
parent
a879c2fb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
147 additions
and
0 deletions
+147
-0
benchmark/kernels/quantization/bench_int8_quant.py
benchmark/kernels/quantization/bench_int8_quant.py
+94
-0
python/sglang/srt/layers/quantization/int8_kernel.py
python/sglang/srt/layers/quantization/int8_kernel.py
+53
-0
No files found.
benchmark/kernels/quantization/bench_int8_quant.py
0 → 100644
View file @
85b2e057
import
argparse
import
torch
import
triton
from
vllm._custom_ops
import
scaled_int8_quant
as
vllm_scaled_int8_quant
from
sglang.srt.layers.quantization.int8_kernel
import
per_token_quant_int8
@
torch
.
compile
(
backend
=
"inductor"
)
def
torch_int8_quant
(
x
):
int8_max
=
torch
.
iinfo
(
torch
.
int8
).
max
abs_max
=
x
.
abs
().
max
(
dim
=-
1
,
keepdim
=
True
).
values
scales
=
abs_max
.
to
(
torch
.
float32
)
/
float
(
int8_max
)
q_x
=
(
x
/
scales
).
round
().
to
(
torch
.
int8
)
return
q_x
,
scales
def
_test_accuracy_once
(
M
,
K
,
input_dtype
,
device
):
x
=
torch
.
randn
(
M
,
K
,
dtype
=
input_dtype
,
device
=
device
)
*
5000
out
,
scales
,
_
=
vllm_scaled_int8_quant
(
x
,
symmetric
=
True
)
out1
,
scales1
=
per_token_quant_int8
(
x
)
out2
,
scales2
=
torch_int8_quant
(
x
)
torch
.
testing
.
assert_close
(
out
,
out2
,
atol
=
1
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
out
,
out1
,
atol
=
1
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
scales
,
scales2
)
torch
.
testing
.
assert_close
(
scales1
,
scales2
)
print
(
f
"M:
{
M
}
, K:
{
K
}
, type:
{
input_dtype
}
OK"
)
def
test_accuracy
():
Ms
=
[
1
,
13
,
128
,
1024
,
2048
,
4096
]
Ks
=
[
512
,
1024
,
2048
,
8192
]
input_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
]
for
M
in
Ms
:
for
K
in
Ks
:
for
input_dtype
in
input_dtypes
:
_test_accuracy_once
(
M
,
K
,
input_dtype
,
"cuda"
)
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"batch_size"
],
x_vals
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
],
x_log
=
False
,
line_arg
=
"provider"
,
line_vals
=
[
"vllm op"
,
"triton"
,
"torch.compile"
],
line_names
=
[
"vllm op"
,
"triton"
,
"torch.compile"
],
styles
=
[(
"blue"
,
"-"
),
(
"orange"
,
"-"
),
(
"red"
,
"-"
)],
ylabel
=
"ms"
,
plot_name
=
"int8 per token quant"
,
args
=
{},
)
)
def
benchmark
(
batch_size
,
provider
):
M
,
K
=
batch_size
,
16384
x
=
torch
.
randn
(
M
,
K
,
dtype
=
torch
.
float16
,
device
=
"cuda"
)
*
1000
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"vllm op"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
vllm_scaled_int8_quant
(
x
,
symmetric
=
True
),
quantiles
=
quantiles
,
)
if
provider
==
"triton"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
per_token_quant_int8
(
x
),
quantiles
=
quantiles
,
)
if
provider
==
"torch.compile"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
torch_int8_quant
(
x
),
quantiles
=
quantiles
,
)
return
ms
,
min_ms
,
max_ms
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--save_path"
,
type
=
str
,
default
=
"./bench_int8_quant_res"
,
help
=
"Path to save int8 quant benchmark results"
,
)
args
=
parser
.
parse_args
()
test_accuracy
()
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
,
save_path
=
args
.
save_path
)
python/sglang/srt/layers/quantization/int8_kernel.py
0 → 100644
View file @
85b2e057
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_per_token_quant_int8
(
x_ptr
,
xq_ptr
,
scale_ptr
,
stride_x
,
stride_xq
,
N
,
BLOCK
:
tl
.
constexpr
,
):
# Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
row_id
=
tl
.
program_id
(
0
)
cols
=
tl
.
arange
(
0
,
BLOCK
)
mask
=
cols
<
N
x
=
tl
.
load
(
x_ptr
+
row_id
*
stride_x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
x
)),
1e-10
)
scale_x
=
absmax
/
127
x_q
=
tl
.
extra
.
cuda
.
libdevice
.
round
(
x
/
scale_x
).
to
(
tl
.
int8
)
tl
.
store
(
xq_ptr
+
row_id
*
stride_xq
+
cols
,
x_q
,
mask
=
mask
)
tl
.
store
(
scale_ptr
+
row_id
,
scale_x
)
def
per_token_quant_int8
(
x
):
M
=
x
.
numel
()
//
x
.
shape
[
-
1
]
N
=
x
.
shape
[
-
1
]
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
torch
.
int8
)
scales
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
1
,),
device
=
x
.
device
,
dtype
=
torch
.
float32
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
assert
x
.
is_contiguous
()
_per_token_quant_int8
[(
M
,)](
x
,
x_q
,
scales
,
stride_x
=
x
.
stride
(
-
2
),
stride_xq
=
x_q
.
stride
(
-
2
),
N
=
N
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
return
x_q
,
scales
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment