change / sglang · Commits · 3eb4a800

Fix AWQ Dequant and Weight Loading of deepseek v2 (#6842)

Unverified commit 3eb4a800, authored Jun 18, 2025 by AniZpZ; committed by GitHub, Jun 17, 2025. Parent: e7261315.

Showing 3 changed files with 18 additions and 11 deletions:
  python/sglang/srt/models/deepseek_v2.py  (+7 −1)
  sgl-kernel/csrc/gemm/awq_kernel.cu       (+9 −7)
  sgl-kernel/tests/test_awq_dequant.py     (+2 −3)
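Read together, the three diffs tell one story: AWQ-style quantized checkpoints store packed weights transposed relative to plain nn.Linear weights, so the DeepSeek-V2 loader has to fuse q_a_proj and kv_a_proj along the packed output-feature axis rather than dim 0, and the dequant kernel has to round its launch grid up and bounds-check so that shapes which are not multiples of its 16×16 thread tile are still fully dequantized. The test matrix gains one row size and one column size that exercise exactly those paths.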
python/sglang/srt/models/deepseek_v2.py · view file @ 3eb4a800

```diff
@@ -2137,8 +2137,14 @@ class DeepseekV2ForCausalLM(nn.Module):
                     ):
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        cat_dim = 0
+                        if (
+                            self.quant_config.get_name() == "awq"
+                            or self.quant_config.get_name() == "moe_wna16"
+                        ):
+                            cat_dim = 1
                         fused_weight = torch.cat(
-                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                         )
                         param_name = (
                             name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
```
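To see why the concatenation axis flips, compare the two storage layouts. A plain nn.Linear weight is [out_features, in_features], so stacking the two projections' outputs concatenates along dim 0; AWQ and moe_wna16 instead keep a packed qweight of roughly [in_features, out_features / pack_factor], which puts the output features on dim 1. A minimal sketch, using DeepSeek-V2-like sizes purely for illustration:

```python
import torch

hidden, q_rank, kv_rank = 5120, 1536, 576  # DeepSeek-V2-like sizes
pack_factor = 8                            # eight 4-bit values per int32

# Unquantized layout: [out_features, in_features] -> stack outputs on dim 0.
q_a = torch.empty(q_rank, hidden)
kv_a = torch.empty(kv_rank, hidden)
fused = torch.cat([q_a, kv_a], dim=0)
assert fused.shape == (q_rank + kv_rank, hidden)

# AWQ-packed layout: [in_features, out_features // pack_factor]
# -> the output features live on dim 1, so fuse there instead.
q_a_packed = torch.empty(hidden, q_rank // pack_factor, dtype=torch.int32)
kv_a_packed = torch.empty(hidden, kv_rank // pack_factor, dtype=torch.int32)
fused_packed = torch.cat([q_a_packed, kv_a_packed], dim=1)
assert fused_packed.shape == (hidden, (q_rank + kv_rank) // pack_factor)
```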
sgl-kernel/csrc/gemm/awq_kernel.cu · view file @ 3eb4a800

```diff
@@ -130,10 +130,12 @@ __global__ void __launch_bounds__(256) dequantize_weights(
     int* __restrict__ qzeros,
     OutputT* __restrict__ output,
     int group_size,
-    int qweight_cols) {
+    int qweight_cols,
+    int qweight_rows) {
 #if CUDA_VERSION >= 12000
   int col = blockIdx.x * blockDim.x + threadIdx.x;
   int row = blockIdx.y * blockDim.y + threadIdx.y;
+  if (col >= qweight_cols || row >= qweight_rows) return;
   int group_idx = row / group_size;
   int scale_offset = 8 * col + group_idx * qweight_cols * 8;
```
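The new qweight_rows parameter exists to make this guard possible: once the launch grid is rounded up (next hunk), threads past the last row or column would otherwise compute scale and zero offsets beyond the tensors and write past the end of output.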
```diff
@@ -188,8 +190,8 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   int x_num_threads = 16;
   int y_num_threads = 16;
-  int x_blocks = qweight_cols / x_num_threads;
-  int y_blocks = qweight_rows / y_num_threads;
+  int x_blocks = (qweight_cols + x_num_threads - 1) / x_num_threads;
+  int y_blocks = (qweight_rows + y_num_threads - 1) / y_num_threads;

   const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));
```
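The arithmetic behind the fix, using the 72-column case the updated test adds: truncating integer division under-counts blocks whenever a size is not a multiple of the 16-thread tile, so trailing columns were never dequantized; ceil division over-launches slightly and relies on the kernel's new early return. A sketch of the same computation in Python (illustrative numbers only):

```python
x_num_threads = 16
qweight_cols = 72  # new test width; 576 output features / 8 values per int32

old_blocks = qweight_cols // x_num_threads  # 4 -> 64 lanes; 8 columns skipped
new_blocks = (qweight_cols + x_num_threads - 1) // x_num_threads  # 5 -> 80 lanes
launched = new_blocks * x_num_threads       # 80 threads along x
in_bounds = [t for t in range(launched) if t < qweight_cols]
assert old_blocks * x_num_threads == 64 and len(in_bounds) == 72
```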
```diff
@@ -206,13 +208,13 @@ torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch:
   if (scales.scalar_type() == at::ScalarType::Half) {
     auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
     auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
-    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
-        _qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<half><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   } else {
     auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
     auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
-    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
-        _qweight, _scales, _zeros, _output, group_size, qweight_cols);
+    dequantize_weights<__nv_bfloat16><<<num_blocks, threads_per_block, 0, stream>>>(
+        _qweight, _scales, _zeros, _output, group_size, qweight_cols, qweight_rows);
   }
   return output;
```
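For context, invoking the op from Python looks roughly like the snippet below. The import path matches how sgl-kernel's own tests call it, but the scales/qzeros shapes and group size here are assumptions based on the usual AWQ layout, not authoritative API documentation:

```python
import torch
from sgl_kernel import awq_dequantize  # binding name as used in sgl-kernel tests

group_size = 64        # assumed; AWQ commonly uses 64 or 128
rows, cols = 128, 72   # 72 is not a multiple of the 16-thread tile
qweight = torch.randint(
    0, torch.iinfo(torch.int32).max, (rows, cols),
    dtype=torch.int32, device="cuda",
)
scales = torch.rand(rows // group_size, cols * 8, dtype=torch.float16, device="cuda")
qzeros = torch.randint(
    0, torch.iinfo(torch.int32).max, (rows // group_size, cols),
    dtype=torch.int32, device="cuda",
)
out = awq_dequantize(qweight, scales, qzeros)  # -> [rows, cols * 8], fp16
assert out.shape == (rows, cols * 8)
```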
sgl-kernel/tests/test_awq_dequant.py · view file @ 3eb4a800

```diff
@@ -67,8 +67,8 @@ def sglang_awq_dequantize(
     "qweight_row,qweight_col,is_bf16_act",
     list(
         itertools.product(
-            [3584, 18944, 128, 256, 512, 1024],
-            [448, 576, 4736, 16, 32, 64, 128],
+            [3584, 18944, 128, 256, 512, 1024, 1536],
+            [448, 576, 4736, 16, 32, 64, 128, 72],
             [True, False],
         )
     ),
```
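The two new sizes look deliberate: 1536 plausibly matches DeepSeek-V2's q_lora_rank, and 72 is the packed column count of the kv_a projection (kv_lora_rank 512 + qk_rope_head_dim 64 = 576 output features, packed 8 per int32 = 72 columns). 72 is also the only entry in either list that is not a multiple of the kernel's 16-thread block edge, so it is the case that drives the new ceil-division and bounds-check path:

```python
rows = [3584, 18944, 128, 256, 512, 1024, 1536]
cols = [448, 576, 4736, 16, 32, 64, 128, 72]
# Only 72 leaves a remainder against the 16-wide thread tile.
print([n for n in rows + cols if n % 16 != 0])  # -> [72]
```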
```diff
@@ -77,7 +77,6 @@ def test_awq_dequant_compare_implementations(
     qweight_row: int, qweight_col: int, is_bf16_act: bool
 ):
     device = torch.device("cuda")
-
     qweight = torch.randint(
         0,
         torch.iinfo(torch.int32).max,
```
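The test compares the CUDA kernel against a reference dequantizer. For readers who want the unpacking scheme in plain PyTorch, here is a minimal CPU-friendly sketch; the nibble ordering and group broadcasting follow the common AWQ convention (as used by AutoAWQ/vLLM-style references) and are assumptions, not code from this repository:

```python
import torch

# Assumed AWQ nibble order when unpacking 8 x 4-bit values from one int32.
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]

def awq_dequantize_ref(qweight, scales, qzeros, group_size=64):
    """Illustrative pure-PyTorch AWQ INT4 dequant: [rows, cols] -> [rows, cols * 8]."""
    shifts = torch.arange(0, 32, 4, device=qweight.device)
    # Unpack the 8 nibbles of every int32 along a trailing axis.
    iweight = (qweight[:, :, None] >> shifts) & 0xF  # [rows, cols, 8]
    izeros = (qzeros[:, :, None] >> shifts) & 0xF    # [rows // g, cols, 8]
    order = torch.tensor(AWQ_REVERSE_ORDER, device=qweight.device)
    iweight = iweight[..., order].reshape(qweight.shape[0], -1)
    izeros = izeros[..., order].reshape(qzeros.shape[0], -1)
    # Each group of `group_size` rows shares one scale/zero row.
    scales = scales.repeat_interleave(group_size, dim=0)
    izeros = izeros.repeat_interleave(group_size, dim=0)
    return (iweight - izeros).to(scales.dtype) * scales
```

Within these assumptions, awq_dequantize_ref(qweight.cpu(), scales.cpu(), qzeros.cpu()) should agree with the kernel's output up to dtype rounding.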