sglang · Commit 321ab756 (Unverified)

[1/3] fix dsv3 awq issue (#4556)

Authored Mar 22, 2025 by AniZpZ, committed by GitHub Mar 22, 2025
Co-authored-by: leoneo <1320612015@qq.com>
Parent: 38f25e87
Showing 2 changed files with 182 additions and 30 deletions:

    sgl-kernel/csrc/gemm/awq_kernel.cu     +110 -23
    sgl-kernel/tests/test_awq_dequant.py    +72  -7
File: sgl-kernel/csrc/gemm/awq_kernel.cu
@@ -3,6 +3,16 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <cuda_fp16.h>
 #include <torch/all.h>
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#include <cuda_bf16.h>
+#endif
+
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
 
 __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
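
A note on the new `lop3` helper: PTX `lop3.b32` evaluates an arbitrary three-input bitwise function, selected by the 8-bit immediate `lut`, independently at every bit position. The immediate `(0xf0 & 0xcc) | 0xaa` used in `dequantize_s4_to_bf16x2` below evaluates to `0xEA`, which encodes `f(a, b, c) = (a & b) | c`. A minimal Python model of the instruction (illustrative only, not part of the commit):

```python
def lop3(a: int, b: int, c: int, lut: int) -> int:
    """Software model of PTX lop3.b32: the bits of (a, b, c) at each
    position form a 3-bit index into the 8-entry truth table `lut`."""
    res = 0
    for i in range(32):
        idx = (((a >> i) & 1) << 2) | (((b >> i) & 1) << 1) | ((c >> i) & 1)
        res |= ((lut >> idx) & 1) << i
    return res

# 0xEA == (0xf0 & 0xcc) | 0xaa encodes f(a, b, c) = (a & b) | c, so
# lop3<0xEA>(i4s, MASK, EX) masks out one nibble and ORs in exponent bits.
assert lop3(0x1234ABCD, 0x000F000F, 0x43004300, 0xEA) == (0x1234ABCD & 0x000F000F) | 0x43004300
```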
@@ -68,32 +78,102 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
 #endif
 }
 
+__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  uint4 result;
+  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
+  uint32_t const i4s = source;
+
+  // Define masks and constants
+  static constexpr uint32_t MASK = 0x000f000f;  // low nibble of each 16-bit lane
+  static constexpr uint32_t EX = 0x43004300;    // bf16x2 exponent bits for 2^7 per lane
+  static constexpr uint32_t MUL = 0x3F803F80;   // bf16x2 {1.0, 1.0}
+  static constexpr uint32_t ADD = 0xC300C300;   // bf16x2 {-128.0, -128.0}
+
+  int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s, MASK, EX);
+  int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 4, MASK, EX);
+  int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 8, MASK, EX);
+  int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 12, MASK, EX);
+
+  nv_bfloat162* res = reinterpret_cast<nv_bfloat162*>(h);
+  res[0] = __hfma2(
+      *reinterpret_cast<nv_bfloat162*>(&lo0),
+      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  res[1] = __hfma2(
+      *reinterpret_cast<nv_bfloat162*>(&hi0),
+      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  res[2] = __hfma2(
+      *reinterpret_cast<nv_bfloat162*>(&lo1),
+      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  res[3] = __hfma2(
+      *reinterpret_cast<nv_bfloat162*>(&hi1),
+      *reinterpret_cast<const nv_bfloat162*>(&MUL),
+      *reinterpret_cast<const nv_bfloat162*>(&ADD));
+  return result;
+#else
+  assert(false);
+  return {};
+#endif
+}
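
The constants make the conversion exact: `EX = 0x4300` per 16-bit lane is a bf16 with exponent 2^7 and an empty mantissa, so ORing a nibble `n` into the mantissa produces the value `128 + n`; the `__hfma2` with `MUL = 1.0` and `ADD = -128.0` then recovers `n` with no rounding. A host-side sanity check of that bit trick (a sketch, not from the commit):

```python
import torch

# For every nibble n in [0, 16), the bit pattern 0x4300 | n, viewed as
# bfloat16, equals 128.0 + n; fma(x, 1.0, -128.0) therefore yields n.
for n in range(16):
    x = torch.tensor([0x4300 | n], dtype=torch.int16).view(torch.bfloat16).item()
    assert x * 1.0 - 128.0 == float(n)
```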
+
+template <typename OutputT>
 __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qweight,
-    half* __restrict__ scales,
+    OutputT* __restrict__ scales,
     int* __restrict__ qzeros,
-    half* __restrict__ output,
+    OutputT* __restrict__ output,
     int group_size,
     int qweight_cols) {
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;
-  uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + (row / group_size) * qweight_cols]);
-  uint4 loaded_scale = *(uint4*)(scales + 8 * col + (row / group_size) * qweight_cols * 8);
-  uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);
-  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
-  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
-  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
-  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
-  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
-  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
-  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
-  asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));
-  half* output_ptr = output + 8 * col + 8 * row * qweight_cols;
-  *(uint4*)output_ptr = weight_fp16;
+  int group_idx = row / group_size;
+  int scale_offset = 8 * col + group_idx * qweight_cols * 8;
+  uint4 loaded_scale = *(uint4*)(scales + scale_offset);
+
+  // Handle different data types
+  if constexpr (std::is_same<OutputT, half>::value) {
+    // FP16 path
+    uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + group_idx * qweight_cols]);
+    uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);
+
+    // Use PTX assembly for FP16 operations
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
+    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
+    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
+    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
+    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));
+
+    OutputT* output_ptr = output + 8 * col + 8 * row * qweight_cols;
+    *(uint4*)output_ptr = weight_fp16;
+  } else if constexpr (std::is_same<OutputT, __nv_bfloat16>::value) {
+    uint4 weight_raw = dequantize_s4_to_bf16x2(qweight[col + row * qweight_cols]);
+    uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[col + group_idx * qweight_cols]);
+    uint4 scale_raw = *reinterpret_cast<uint4*>(scales + scale_offset);
+
+    // Vectorized processing (each uint4 contains 4 nv_bfloat162)
+    nv_bfloat162* weight_vec = reinterpret_cast<nv_bfloat162*>(&weight_raw);
+    nv_bfloat162* zero_vec = reinterpret_cast<nv_bfloat162*>(&zero_raw);
+    nv_bfloat162* scale_vec = reinterpret_cast<nv_bfloat162*>(&scale_raw);
+
+    // Single instruction dual-channel operation
+#pragma unroll
+    for (int i = 0; i < 4; ++i) {  // uint4 = 4 * nv_bfloat162
+      weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]);
+    }
+
+    // Directly store to OutputT array (guaranteed contiguous memory)
+    OutputT* output_ptr = output + 8 * col + row * qweight_cols * 8;
+    static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch");
+    *reinterpret_cast<uint4*>(output_ptr) = weight_raw;
+  }
 }
 
 torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
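
Read scalar-wise, each thread owns one packed `int32` of `qweight` and emits eight dequantized values, `out[row, 8*col : 8*col+8] = (w - z) * s`, with the zero-point and scale taken from the row's group. A rough per-thread reference in PyTorch (a hypothetical helper for exposition; it unpacks nibbles in plain sequential order and ignores AWQ's interleaved in-register order, which `dequantize_s4_to_fp16x2` and `dequantize_s4_to_bf16x2` handle):

```python
import torch

# Hypothetical per-thread reference for dequantize_weights; illustrates
# the indexing only, not the in-register nibble ordering.
def dequant_one_thread(qweight, scales, qzeros, group_size, row, col):
    shifts = torch.arange(0, 32, 4, device=qweight.device)
    g = row // group_size                       # group_idx in the kernel
    w = (qweight[row, col] >> shifts) & 0xF     # eight packed weights
    z = (qzeros[g, col] >> shifts) & 0xF        # eight packed zero-points
    s = scales[g, 8 * col : 8 * col + 8]        # scale_offset = 8*col + g*cols*8
    return (w.to(s.dtype) - z.to(s.dtype)) * s  # -> output[row, 8*col : 8*col+8]
```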
@@ -112,16 +192,23 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);
 
   auto _qweight = reinterpret_cast<int*>(qweight.data_ptr<int>());
-  auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
   auto _zeros = reinterpret_cast<int*>(qzeros.data_ptr<int>());
-  auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
 
   dim3 num_blocks(x_blocks, y_blocks);
   dim3 threads_per_block(x_num_threads, y_num_threads);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
-      _qweight, _scales, _zeros, _output, group_size, qweight_cols);
+  if (scales.scalar_type() == at::ScalarType::Half) {
+    auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
+    auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
+    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols);
+  } else {
+    auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
+    auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
+    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols);
+  }
 
   return output;
 }
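
From the Python side, the visible effect is that `awq_dequantize` now dispatches on the dtype of `scales`; since the bf16 branch reads `output.data_ptr<at::BFloat16>()`, the output tensor must be allocated with the same dtype as `scales` (the `output_tensor_options` setup is outside this hunk). An illustrative call with made-up shapes (R = 256 rows, C = 128 output columns, group size 128):

```python
import torch
from sgl_kernel import awq_dequantize

# Hypothetical shapes: qweight is [R, C // 8] int32, qzeros [R // G, C // 8],
# scales [R // G, C]; bf16 scales select the new bf16 kernel path.
qweight = torch.randint(
    0, torch.iinfo(torch.int32).max, (256, 16), dtype=torch.int32, device="cuda"
)
qzeros = torch.randint(
    0, torch.iinfo(torch.int32).max, (2, 16), dtype=torch.int32, device="cuda"
)
scales = torch.rand(2, 128, dtype=torch.bfloat16, device="cuda")

out = awq_dequantize(qweight, scales, qzeros)
assert out.shape == (256, 128) and out.dtype == torch.bfloat16
```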
File: sgl-kernel/tests/test_awq_dequant.py
@@ -7,6 +7,57 @@ from sgl_kernel import awq_dequantize
 from vllm import _custom_ops as ops
 
 
+def reverse_awq_order(t: torch.Tensor):
+    bits = 4
+    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
+    reverse_order_tensor = torch.arange(
+        t.shape[-1],
+        dtype=torch.int32,
+        device=t.device,
+    )
+    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
+    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
+    reverse_order_tensor = reverse_order_tensor.view(-1)
+
+    t = t[:, reverse_order_tensor] & 0xF
+    return t
+
+
+# qweights - [R     , C // 8], int32
+# scales   - [R // G, C     ], float16
+# zeros    - [R // G, C // 8], int32
+def awq_dequantize_torch(
+    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
+) -> torch.Tensor:
+    if group_size == -1:
+        group_size = qweight.shape[0]
+
+    bits = 4
+    shifts = torch.arange(0, 32, bits, device=qzeros.device)
+
+    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
+        torch.int8
+    )
+    iweights = iweights.view(iweights.shape[0], -1)
+
+    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
+        torch.int8
+    )
+    zeros = zeros.view(qzeros.shape[0], -1)
+    zeros = reverse_awq_order(zeros)
+
+    iweights = reverse_awq_order(iweights)
+
+    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
+    zeros = torch.bitwise_and(zeros, (2**bits) - 1)
+
+    scales = scales.repeat_interleave(group_size, dim=0)
+    zeros = zeros.repeat_interleave(group_size, dim=0)
+    return (iweights - zeros) * scales
+
+
 def vllm_awq_dequantize(
     qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
 ) -> torch.Tensor:
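
`reverse_awq_order` undoes AWQ's physical nibble layout inside each packed `int32`. Assuming the conventional AWQ pack order `[0, 2, 4, 6, 1, 3, 5, 7]` (the inverse permutation of `AWQ_REVERSE_ORDER`; an assumption, not stated in this diff), a quick round-trip check using the helper above:

```python
import torch

# Pack logical nibbles 0..7 in AWQ's assumed physical slot order, then
# verify that sequential unpacking followed by reverse_awq_order restores
# logical order 0..7.
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
packed = 0
for slot, logical in enumerate(AWQ_PACK_ORDER):
    packed |= logical << (4 * slot)

t = torch.tensor([[packed]], dtype=torch.int32)
nibbles = (t[:, :, None] >> torch.arange(0, 32, 4)).view(1, -1)
assert reverse_awq_order(nibbles).tolist() == [[0, 1, 2, 3, 4, 5, 6, 7]]
```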
@@ -20,16 +71,17 @@ def sglang_awq_dequantize(
 @pytest.mark.parametrize(
-    "qweight_row,qweight_col",
+    "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024], [448, 576, 4736, 16, 32, 64, 128]
+            [3584, 18944, 128, 256, 512, 1024],
+            [448, 576, 4736, 16, 32, 64, 128],
+            [True, False],
         )
     ),
 )
 def test_awq_dequant_compare_implementations(
-    qweight_row: int,
-    qweight_col: int,
+    qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
@@ -43,7 +95,12 @@ def test_awq_dequant_compare_implementations(
         group_size = qweight_row
     scales_row = qweight_row // group_size
     scales_col = qweight_col * 8
-    scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
+    if is_bf16_act:
+        scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)
+    else:
+        scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)
     qzeros = torch.randint(
         0,
         torch.iinfo(torch.int32).max,
@@ -53,13 +110,21 @@ def test_awq_dequant_compare_implementations(
     )
 
     # Run both implementations
-    vllm_out = vllm_awq_dequantize(qweight, scales, qzeros)
+    vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
+    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
     sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
 
     # Compare results
     torch.testing.assert_close(
-        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
+        torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
     )
+    if not is_bf16_act:
+        torch.testing.assert_close(
+            vllm_out.to(torch.float32),
+            sglang_out.to(torch.float32),
+            rtol=1e-3,
+            atol=1e-5,
+        )
 
 
 if __name__ == "__main__":
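
Note the comparison strategy after this change: the new `awq_dequantize_torch` serves as the reference in both dtypes, while the vllm reference path still runs in fp16 (hence `scales.to(torch.float16)` when calling it). The sglang-vs-vllm assertion therefore only applies when the sglang path also ran in fp16; in bf16 mode, only the PyTorch reference, computed directly in bf16, is compared, since a 1e-5 atol comparison against an fp16-computed result would be too strict for bf16 rounding.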