Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
982c1545
Commit
982c1545
authored
Aug 05, 2024
by
gaoqiong
Browse files
add qwen awq support
parent
ec98d390
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
13 deletions
+55
-13
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+1
-1
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+52
-0
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+2
-12
No files found.
vllm/model_executor/model_loader/loader.py
View file @
982c1545
...
@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader):
for
_
,
module
in
model
.
named_modules
():
for
_
,
module
in
model
.
named_modules
():
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
quant_method
=
getattr
(
module
,
"quant_method"
,
None
)
if
quant_method
is
not
None
and
quant_method
!=
"awq"
:
:
if
quant_method
is
not
None
and
quant_method
!=
"awq"
:
quant_method
.
process_weights_after_loading
(
module
)
quant_method
.
process_weights_after_loading
(
module
)
# FIXME: Remove this after Mixtral is updated
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
# to use quant_method.
...
...
vllm/model_executor/models/qwen.py
View file @
982c1545
...
@@ -245,6 +245,17 @@ class QWenLMHeadModel(nn.Module):
...
@@ -245,6 +245,17 @@ class QWenLMHeadModel(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
if
self
.
quant_method
==
"awq"
:
try
:
import
lmslim
except
ValueError
as
e
:
raise
RuntimeError
(
"please install lmslim first for awq
\n
"
)
from
e
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
...
@@ -339,4 +350,45 @@ class QWenLMHeadModel(nn.Module):
...
@@ -339,4 +350,45 @@ class QWenLMHeadModel(nn.Module):
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
weight
.
data
=
weight
.
data
.
reshape
(
ori_shape
[
1
],
-
1
)
if
self
.
quant_method
==
"awq"
:
from
lmslim
import
quant_ops
as
_ops
lay_key_words
=
[
"attn.c_attn.qweight"
,
"attn.c_proj.qweight"
,
"mlp.gate_up_proj.qweight"
,
"mlp.c_proj.qweight"
]
combined_words
=
"|"
.
join
(
lay_key_words
)
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
qweight
=
params_dict
[
layername
]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
zeros_and_scalse
=
params_dict
[
layername
.
replace
(
"qweight"
,
"zeros_and_scales"
)]
group_size
=
self
.
quant_config
.
group_size
dim_n
=
scales
.
data
.
shape
[
1
]
dim_k
=
qweight
.
data
.
shape
[
0
]
pad_group
=
2
_qw
,
_sz
=
_ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
sz
=
_ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
zeros_and_scalse
.
data
.
copy_
(
sz
)
qweight
.
data
.
copy_
(
_qw
)
#reshape
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
vllm/model_executor/models/qwen2.py
View file @
982c1545
...
@@ -326,8 +326,8 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -326,8 +326,8 @@ class Qwen2ForCausalLM(nn.Module):
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
quant_method
=
None
self
.
quant_method
=
None
if
quant_config
is
not
None
:
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
...
@@ -440,7 +440,7 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -440,7 +440,7 @@ class Qwen2ForCausalLM(nn.Module):
if
self
.
quant_method
==
"awq"
:
if
self
.
quant_method
==
"awq"
:
from
lmslim
import
quant_ops
as
_ops
from
lmslim
import
quant_ops
as
_ops
# 对weight进行处理转置处理
lay_key_words
=
[
lay_key_words
=
[
"self_attn.qkv_proj.qweight"
,
"self_attn.qkv_proj.qweight"
,
"self_attn.o_proj.qweight"
,
"self_attn.o_proj.qweight"
,
...
@@ -453,7 +453,6 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -453,7 +453,6 @@ class Qwen2ForCausalLM(nn.Module):
matches
=
re
.
findall
(
combined_words
,
layername
)
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
:
#只对.qweight做了匹配,但是对对应的scale和qzeros都做了处理
qweight
=
params_dict
[
layername
]
qweight
=
params_dict
[
layername
]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
...
@@ -465,18 +464,10 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -465,18 +464,10 @@ class Qwen2ForCausalLM(nn.Module):
dim_k
=
qweight
.
data
.
shape
[
0
]
dim_k
=
qweight
.
data
.
shape
[
0
]
pad_group
=
2
pad_group
=
2
#对qweight和qzeros以及scales进行pad
#qweight[k,n/8]--->[k+group_size*2,n/8]
#qzeros [k/group_size+2,n/8]
#scales [k/group_size+2,n]
#给weight进行转置和zeros_and_scales打包
_qw
,
_sz
=
_ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
_qw
,
_sz
=
_ops
.
convert_s4
(
qweight
,
qzeros
,
scales
,
int
(
group_size
))
#给sz转置(转置之后但是暂时保留原来的shape信息)
sz
=
_ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
sz
=
_ops
.
sz_permute
(
_sz
).
reshape
(
-
1
,
dim_n
)
#数据拷贝
zeros_and_scalse
.
data
.
copy_
(
sz
)
zeros_and_scalse
.
data
.
copy_
(
sz
)
qweight
.
data
.
copy_
(
_qw
)
qweight
.
data
.
copy_
(
_qw
)
...
@@ -484,7 +475,6 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -484,7 +475,6 @@ class Qwen2ForCausalLM(nn.Module):
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
#对qweight 与zeros_and_scalse 进行pad
if
dim_k
%
4096
==
0
:
if
dim_k
%
4096
==
0
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment