Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f5ddc3d
Commit
5f5ddc3d
authored
Aug 04, 2024
by
gaoqiong
Browse files
add llama model awq support
parent
bdac8f06
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
110 additions
and
15 deletions
+110
-15
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+54
-13
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+1
-1
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+54
-0
No files found.
vllm/config.py
View file @
5f5ddc3d
...
...
@@ -172,7 +172,7 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"awq"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
5f5ddc3d
...
...
@@ -8,6 +8,14 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
try
:
from
lmslim
import
quant_ops
as
_ops
except
Exception
:
print
(
"INFO:you need install lmslim if you want infer awq model.
\n
"
)
class
AWQShareWorkSpace
():
awqworkshapcesize
=
2
<<
29
#
awqworkshapce
=
torch
.
zeros
(
awqworkshapcesize
//
2
+
1
,
dtype
=
torch
.
float16
).
cuda
()
class
AWQConfig
(
QuantizationConfig
):
...
...
@@ -142,6 +150,19 @@ class AWQLinearMethod(LinearMethodBase):
"input_dim"
:
0
,
"output_dim"
:
1
,
})
zeros_and_scales
=
Parameter
(
torch
.
empty
(
(
input_size_per_partition
//
self
.
quant_config
.
group_size
),
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
zeros_and_scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
...
...
@@ -149,27 +170,47 @@ class AWQLinearMethod(LinearMethodBase):
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"zeros_and_scales"
,
zeros_and_scales
)
set_weight_attrs
(
zeros_and_scales
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
scales
=
layer
.
scales
qzeros
=
layer
.
qzeros
pack_factor
=
self
.
quant_config
.
pack_factor
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
]
*
pack_factor
,
))
zeros_and_scales
=
layer
.
zeros_and_scales
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
0
]
*
1
,
))
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION
=
x
.
shape
[
:
-
1
]
.
numel
()
>=
256
if
FP16_MATMUL_HEURISTIC_CONDITION
:
out
=
ops
.
awq_dequantize
(
qweight
,
scales
,
qzeros
,
0
,
0
,
0
)
out
=
torch
.
matmul
(
reshaped_x
,
out
)
m
=
reshaped_x
.
shape
[
0
]
k
=
reshaped_
x
.
shape
[
-
1
]
n
=
qweight
.
shape
[
0
]
if
k
%
4096
==
0
:
padding_group
=
2
else
:
out
=
ops
.
awq_gemm
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
pack_factor
)
padding_group
=
0
out
=
_ops
.
awq_gemm
(
reshaped_x
,
qweight
,
zeros_and_scales
,
m
,
n
,
k
,
self
.
quant_config
.
group_size
,
padding_group
,
AWQShareWorkSpace
.
awqworkshapce
,
AWQShareWorkSpace
.
awqworkshapcesize
)
#下面是采用rocblas的做法
# deqweight=_ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# qweight,
# zeros_and_scales,
# k,
# n,
# self.quant_config.group_size)
# output=F.linear(reshaped_x, deqweight)
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
vllm/model_executor/model_loader/loader.py
View file @
5f5ddc3d
...
...
@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader):
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
:
if
quant_method
is
not
None
and
quant_method
!=
"awq"
:
:
quant_method
.
process_weights_after_loading
(
module
)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
...
...
vllm/model_executor/models/llama.py
View file @
5f5ddc3d
...
...
@@ -367,6 +367,18 @@ class LlamaForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
if
self
.
quant_method
==
"awq"
:
try
:
import
lmslim
except
ValueError
as
e
:
raise
RuntimeError
(
"please install lmslim first for awq
\n
"
)
from
e
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
...
...
@@ -476,7 +488,49 @@ class LlamaForCausalLM(nn.Module):
weight
.
data
.
copy_
(
_weight
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
self
.
quant_method
==
"awq"
:
from
lmslim
import
quant_ops
as
_ops
lay_key_words
=
[
"self_attn.qkv_proj.qweight"
,
"self_attn.o_proj.qweight"
,
"mlp.gate_up_proj.qweight"
,
"mlp.down_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
qweight
=
params_dict
[
layername
]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
zeros_and_scalse
=
params_dict
[
layername
.
replace
(
"qweight"
,
"zeros_and_scales"
)]
group_size
=
self
.
quant_config
.
group_size
dim_n
=
scales
.
data
.
shape
[
1
]
dim_k
=
qweight
.
data
.
shape
[
0
]
pad_group
=
2
_qw
,
_sz
=
_ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
sz
=
_ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
zeros_and_scalse
.
data
.
copy_
(
sz
)
qweight
.
data
.
copy_
(
_qw
)
#reshape
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment