change / sglang / Commits / 07f94463

Unverified commit 07f94463, authored Mar 12, 2025 by Rex, committed by GitHub on Mar 12, 2025
Add awq dequantize kernel to sgl with 1x to 3x speedup (#4104)
Parent: e0917e6b
Showing 8 changed files with 324 additions and 0 deletions
sgl-kernel/benchmark/bench_awq_dequant.py   +118 -0
sgl-kernel/csrc/gemm/awq_kernel.cu          +127 -0
sgl-kernel/csrc/torch_extension.cc          +3   -0
sgl-kernel/include/sgl_kernel_ops.h         +1   -0
sgl-kernel/python/sgl_kernel/__init__.py    +1   -0
sgl-kernel/python/sgl_kernel/gemm.py        +6   -0
sgl-kernel/setup.py                         +1   -0
sgl-kernel/tests/test_awq_dequant.py        +67  -0
sgl-kernel/benchmark/bench_awq_dequant.py
new file mode 100644
import itertools
from typing import List, Tuple

import torch
import triton
import triton.testing
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops


def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)


def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    return awq_dequantize(qweight, scales, qzeros)


def calculate_diff(qweight_row: int, qweight_col: int):
    """Calculate difference between VLLM and SGLang implementations."""
    device = torch.device("cuda")
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


qweight_row_range = [3584, 18944, 128, 256, 512, 1024]
qweight_cols_range = [448, 576, 4736, 16, 32, 64, 128]

configs = list(itertools.product(qweight_row_range, qweight_cols_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["qweight_row", "qweight_col"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["VLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="awq-dequantize-performance",
        args={},
    )
)
def benchmark(qweight_row, qweight_col, provider):
    dtype = torch.float16
    device = torch.device("cuda")

    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    quantiles = [0.5, 0.2, 0.8]

    if provider == "vllm":
        fn = lambda: vllm_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )
    elif provider == "sglang":
        fn = lambda: sglang_awq_dequantize(
            qweight.clone(), scales.clone(), qzeros.clone()
        )

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(qweight_row=3584, qweight_col=448)
    benchmark.run(print_data=True)
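A note on the shapes used above (illustration, not part of the commit): each int32 entry of qweight and qzeros packs eight 4-bit values, which is why scales and the dequantized output have qweight_col * 8 columns. The sketch below only demonstrates that 8x column expansion; the kernel in awq_kernel.cu extracts the nibbles in an interleaved order via lop3 and fp16 magic numbers, so the plain low-to-high order here is not bit-exact.

# Illustrative sketch only: eight 4-bit values per int32 column.
import torch

packed = torch.randint(0, torch.iinfo(torch.int32).max, (4, 2), dtype=torch.int32)
shifts = torch.arange(0, 32, 4, dtype=torch.int32)   # 8 nibble positions per int32
nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF     # shape (4, 2, 8)
unpacked = nibbles.reshape(packed.shape[0], packed.shape[1] * 8)
print(packed.shape, "->", unpacked.shape)            # (4, 2) -> (4, 16)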
sgl-kernel/csrc/gemm/awq_kernel.cu
new file mode 100644
// Adapted from
// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350

#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <torch/all.h>

__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  uint4 result;

  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // First, we extract the i4s and construct an intermediate fp16 number.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits before hand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

  // This is the half2 {1024, 1024} represented as an integer.
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-64, -64} represented as an integer.
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#else
  assert(false);
  return {};
#endif
}

__global__ void __launch_bounds__(256) dequantize_weights(
    int* __restrict__ qweight,
    half* __restrict__ scales,
    int* __restrict__ qzeros,
    half* __restrict__ output,
    int group_size,
    int qweight_cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;

  uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);
  uint4 loaded_scale = *(uint4*)(scales + 8 * col + (row / group_size) * qweight_cols * 8);

  uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);

  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));

  half* output_ptr = output + 8 * col + 8 * row * qweight_cols;
  *(uint4*)output_ptr = weight_fp16;
}

torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
  int qweight_rows = qweight.size(0);
  int qweight_cols = qweight.size(1);
  int group_size = qweight_rows / scales.size(0);

  int x_num_threads = 16;
  int y_num_threads = 16;
  int x_blocks = qweight_cols / x_num_threads;
  int y_blocks = qweight_rows / y_num_threads;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

  auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
  at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);

  auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
  auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
  auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
  auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_num_threads, y_num_threads);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
      _qweight, _scales, _zeros, _output, group_size, qweight_cols);

  return output;
}
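The constants in dequantize_s4_to_fp16x2 use the standard int4-to-fp16 magic-number trick: OR-ing a 4-bit value v into the mantissa of 0x6400 (fp16 1024.0) produces 1024 + v, so subtracting 1024 recovers v; for the nibble that sits four bits higher, the packed value is 1024 + 16 * v, and the fma with 1/16 and -64 recovers v. A small standalone check of that arithmetic (illustrative, not part of the commit):

# Standalone check of the fp16 magic-number arithmetic used above (illustrative only).
import torch

def fp16_from_bits(bits: int) -> float:
    # Reinterpret a 16-bit pattern as an IEEE fp16 value.
    return torch.tensor([bits], dtype=torch.int16).view(torch.float16).item()

for v in range(16):
    # Low-nibble path: 0x6400 is fp16 1024.0; OR-ing v into the mantissa gives 1024 + v.
    low = fp16_from_bits(0x6400 | v) - 1024.0
    # High-nibble path: the nibble sits 4 bits higher, so the packed value is 1024 + 16*v;
    # multiplying by 1/16 and subtracting 64 (the fma.rn.f16x2 above) recovers v.
    high = fp16_from_bits(0x6400 | (v << 4)) * (1.0 / 16.0) - 64.0
    assert low == v and high == v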
sgl-kernel/csrc/torch_extension.cc
...
@@ -75,6 +75,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/gemm
    */
+  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
+  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
+
   m.def(
       "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
       "bias) -> Tensor");
...
sgl-kernel/include/sgl_kernel_ops.h
...
@@ -112,6 +112,7 @@ void apply_rope_pos_ids_cos_sin_cache(
 /*
  * From csrc/gemm
  */
+torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros);
 torch::Tensor int8_scaled_mm(
     const torch::Tensor& mat_a,
     const torch::Tensor& mat_b,
...
sgl-kernel/python/sgl_kernel/__init__.py
...
@@ -23,6 +23,7 @@ from sgl_kernel.elementwise import (
     silu_and_mul,
 )
 from sgl_kernel.gemm import (
+    awq_dequantize,
     bmm_fp8,
     cublas_grouped_gemm,
     fp8_blockwise_scaled_mm,
...
sgl-kernel/python/sgl_kernel/gemm.py
...
@@ -4,6 +4,12 @@ import torch
 from sgl_kernel.utils import _get_cache_buf, get_cuda_stream


+def awq_dequantize(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
+) -> torch.ByteTensor:
+    return torch.ops.sgl_kernels.awq_dequantize(qweight, scales, qzeros)
+
+
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     return torch.ops.sgl_kernel.int8_scaled_mm(
         mat_a,
...
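With the op registered in torch_extension.cc and exported through sgl_kernel.gemm, the new entry point can be exercised roughly as sketched below. This is illustrative only, not from the commit: shapes mirror bench_awq_dequant.py with group_size equal to qweight_row, and it assumes a CUDA device plus an sgl-kernel build that includes csrc/gemm/awq_kernel.cu.

# Minimal usage sketch; shapes follow bench_awq_dequant.py with a single scale/zero row.
import torch
from sgl_kernel import awq_dequantize

device = torch.device("cuda")
qweight_row, qweight_col = 3584, 448   # int4 weights packed 8-per-int32

qweight = torch.randint(
    0, torch.iinfo(torch.int32).max, (qweight_row, qweight_col),
    dtype=torch.int32, device=device,
)
scales = torch.rand(1, qweight_col * 8, dtype=torch.float16, device=device)
qzeros = torch.randint(
    0, torch.iinfo(torch.int32).max, (1, qweight_col),
    dtype=torch.int32, device=device,
)

weights = awq_dequantize(qweight, scales, qzeros)
print(weights.shape, weights.dtype)    # torch.Size([3584, 3584]) torch.float16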
sgl-kernel/setup.py
...
@@ -150,6 +150,7 @@ sources = [
     "csrc/elementwise/rope.cu",
     "csrc/gemm/bmm_fp8.cu",
     "csrc/gemm/cublas_grouped_gemm.cu",
+    "csrc/gemm/awq_kernel.cu",
     "csrc/gemm/fp8_gemm_kernel.cu",
     "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
     "csrc/gemm/int8_gemm_kernel.cu",
...
sgl-kernel/tests/test_awq_dequant.py
new file mode 100644
import itertools
from typing import Optional, Tuple

import pytest
import torch
from sgl_kernel import awq_dequantize
from vllm import _custom_ops as ops


def vllm_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)


def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return awq_dequantize(qweight, scales, qzeros)


@pytest.mark.parametrize(
    "qweight_row,qweight_col",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024],
            [448, 576, 4736, 16, 32, 64, 128],
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int,
    qweight_col: int,
):
    device = torch.device("cuda")

    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8
    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    # Run both implementations
    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    # Compare results
    torch.testing.assert_close(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )


if __name__ == "__main__":
    # Run the specific test function directly
    pytest.main([__file__])