Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14d195dc
Commit
14d195dc
authored
Feb 09, 2026
by
yangql1
Browse files
新增awq_marlin gemm的qwen72B的支持
parent
a27f634a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
15 deletions
+62
-15
vllm/_custom_ops.py
vllm/_custom_ops.py
+32
-0
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+24
-15
No files found.
vllm/_custom_ops.py
View file @
14d195dc
...
...
@@ -420,6 +420,24 @@ def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor,
splikspacesize
:
int
)
->
torch
.
Tensor
:
return
torch
.
empty
((
m
,
n
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
def
awq_gemm_marlin_weight_repack
(
weight_trans
:
torch
.
Tensor
,
N
:
int
,
K
:
int
)
->
torch
.
Tensor
:
return
lightop
.
awq_gemm_marlin_weight_repack
(
weight_trans
,
N
,
K
)
def
awq_gemm_marlin_weight_repack_fake
(
weight_trans
:
torch
.
Tensor
,
N
:
int
,
K
:
int
)
->
torch
.
Tensor
:
return
torch
.
empty
((
N
,
K
),
dtype
=
weight_trans
.
dtype
,
device
=
weight_trans
.
device
)
def
gemm_awq_w4a16_marlin
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
zeros_and_scales
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
)
->
torch
.
Tensor
:
return
lightop
.
gemm_awq_w4a16_marlin
(
input
,
weight
,
zeros_and_scales
)
def
gemm_awq_w4a16_marlin_fake
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
zeros_and_scales
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
)
->
torch
.
Tensor
:
return
torch
.
empty
((
m
,
n
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
def
convert_s4
(
qw
:
torch
.
Tensor
,
qz
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
group_size
:
int
):
...
...
@@ -2299,3 +2317,17 @@ direct_register_custom_op(
mutates_args
=
[],
fake_impl
=
gptq_gemm_fake_
,
)
direct_register_custom_op
(
op_name
=
"awq_gemm_marlin_weight_repack"
,
op_func
=
awq_gemm_marlin_weight_repack
,
mutates_args
=
[],
fake_impl
=
awq_gemm_marlin_weight_repack_fake
,
)
direct_register_custom_op
(
op_name
=
"gemm_awq_w4a16_marlin"
,
op_func
=
gemm_awq_w4a16_marlin
,
mutates_args
=
[],
fake_impl
=
gemm_awq_w4a16_marlin_fake
,
)
\ No newline at end of file
vllm/envs.py
View file @
14d195dc
...
...
@@ -95,6 +95,7 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_WITH_STACK
:
bool
=
True
VLLM_TORCH_PROFILER_WITH_FLOPS
:
bool
=
False
VLLM_USE_TRITON_AWQ
:
bool
=
False
AWQ_GEMM_MARLIN
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_DISABLED_KERNELS
:
list
[
str
]
=
[]
...
...
@@ -908,6 +909,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_AWQ"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_TRITON_AWQ"
,
"0"
))),
# If set, vLLM will use marlin implementations of AWQ.
"AWQ_GEMM_MARLIN"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"AWQ_GEMM_MARLIN"
,
"0"
))),
# If set, allow loading or unloading lora adapters in runtime,
"VLLM_ALLOW_RUNTIME_LORA_UPDATING"
:
lambda
:
...
...
@@ -1762,6 +1767,7 @@ def compute_hash() -> str:
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH"
,
"VLLM_USE_TRITON_FLASH_ATTN"
,
"VLLM_USE_TRITON_AWQ"
,
"AWQ_GEMM_MARLIN"
,
"VLLM_DP_RANK"
,
"VLLM_DP_SIZE"
,
"VLLM_USE_STANDALONE_COMPILE"
,
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
14d195dc
...
...
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.model_executor.layers.quantization.awq_triton
import
awq_gemm_triton
import
lightop
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
triton_configs_dict
=
{}
...
...
@@ -205,8 +206,9 @@ class AWQLinearMethod(LinearMethodBase):
"""
def
__init__
(
self
,
quant_config
:
AWQConfig
):
if
not
envs
.
AWQ_GEMM_MARLIN
and
not
envs
.
VLLM_USE_TRITON_AWQ
:
self
.
awqsingleton
=
AWQShareWorkSpace
()
self
.
quant_config
=
quant_config
self
.
awqsingleton
=
AWQShareWorkSpace
()
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -303,7 +305,10 @@ class AWQLinearMethod(LinearMethodBase):
sz
=
torch
.
cat
((
sz
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
_qw
=
torch
.
cat
((
_qw
,
qweight_pad
),
dim
=
1
).
contiguous
()
if
envs
.
AWQ_GEMM_MARLIN
:
_qw
=
torch
.
ops
.
vllm
.
awq_gemm_marlin_weight_repack
(
_qw
,
dim_n
,
dim_k
)
layer
.
qweight
=
torch
.
nn
.
Parameter
(
_qw
,
requires_grad
=
False
)
layer
.
zeros_and_scales
=
torch
.
nn
.
Parameter
(
sz
,
requires_grad
=
False
)
layer
.
qzeros
=
None
...
...
@@ -326,12 +331,13 @@ class AWQLinearMethod(LinearMethodBase):
qzeros
=
layer
.
qzeros
scales
=
layer
.
scales
pack_factor
=
self
.
quant_config
.
pack_factor
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
0
]
*
1
,
))
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
m
=
reshaped_x
.
shape
[
0
]
k
=
reshaped_x
.
shape
[
-
1
]
n
=
qweight
.
shape
[
0
]
n
=
layer
.
output_size
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
n
,
))
if
self
.
use_awq_pad
:
if
k
%
4096
==
0
:
...
...
@@ -346,16 +352,19 @@ class AWQLinearMethod(LinearMethodBase):
out
=
awq_gemm_triton
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
pack_factor
,
best_config
)
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
1
]
*
8
,
))
else
:
out
=
torch
.
ops
.
vllm
.
awq_gemm
(
reshaped_x
,
qweight
,
zeros_and_scales
,
m
,
n
,
k
,
self
.
quant_config
.
group_size
,
padding_group
,
self
.
awqsingleton
.
awqworkshapce
,
self
.
awqsingleton
.
awqworkshapcesize
)
if
envs
.
AWQ_GEMM_MARLIN
:
out
=
torch
.
ops
.
vllm
.
gemm_awq_w4a16_marlin
(
reshaped_x
,
qweight
,
zeros_and_scales
,
m
,
n
,
k
)
else
:
out
=
torch
.
ops
.
vllm
.
awq_gemm
(
reshaped_x
,
qweight
,
zeros_and_scales
,
m
,
n
,
k
,
self
.
quant_config
.
group_size
,
padding_group
,
self
.
awqsingleton
.
awqworkshapce
,
self
.
awqsingleton
.
awqworkshapcesize
)
if
bias
is
not
None
:
out
.
add_
(
bias
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment