Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc4b902f
"vscode:/vscode.git/clone" did not exist on "c29fb540ff90da720490daae58bb4bfe31a91125"
Commit
cc4b902f
authored
Oct 09, 2024
by
zhuwenwen
Browse files
增加AWQ相关环境变量控制
parent
35a8304d
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
7 deletions
+28
-7
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+10
-4
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+12
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-1
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+2
-1
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+2
-1
No files found.
vllm/model_executor/layers/quantization/awq.py
View file @
cc4b902f
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
import
os
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -98,6 +99,8 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -98,6 +99,8 @@ class AWQLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
awqsingleton
=
AWQShareWorkSpace
()
self
.
awqsingleton
=
AWQShareWorkSpace
()
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
AWQ_CK_GEMMBS
=
int
(
os
.
getenv
(
'AWQ_CK_GEMMBS'
,
'20000'
))
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
...
@@ -190,12 +193,15 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -190,12 +193,15 @@ class AWQLinearMethod(LinearMethodBase):
k
=
reshaped_x
.
shape
[
-
1
]
k
=
reshaped_x
.
shape
[
-
1
]
n
=
qweight
.
shape
[
0
]
n
=
qweight
.
shape
[
0
]
if
k
%
4096
==
0
:
if
self
.
use_awq_pad
:
if
k
%
4096
==
0
:
padding_group
=
2
padding_group
=
2
else
:
else
:
padding_group
=
0
padding_group
=
0
else
:
padding_group
=
0
if
m
<
20000
:
if
m
<=
self
.
AWQ_CK_GEMMBS
:
out
=
ops
.
awq_gemm
(
reshaped_x
,
out
=
ops
.
awq_gemm
(
reshaped_x
,
qweight
,
qweight
,
zeros_and_scales
,
zeros_and_scales
,
...
@@ -208,7 +214,7 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -208,7 +214,7 @@ class AWQLinearMethod(LinearMethodBase):
self
.
awqsingleton
.
awqworkshapcesize
)
self
.
awqsingleton
.
awqworkshapcesize
)
else
:
else
:
#下面是采用rocblas的做法
#下面是采用rocblas的做法
deqweight
=
ops
.
dequant_w4_gemm_colmajor
(
#shape[n,k/8]--->[n,k]
deqweight
=
ops
.
dequant_w4_gemm_colmajor
(
#
shape[n,
k/8]
--->
[n,k]
qweight
,
qweight
,
zeros_and_scales
,
zeros_and_scales
,
k
,
k
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
cc4b902f
...
@@ -30,10 +30,22 @@ def get_model_architecture(
...
@@ -30,10 +30,22 @@ def get_model_architecture(
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
try
:
if
os
.
getenv
(
'AWQ_PAD'
)
==
'0'
or
((
torch
.
cuda
.
isCurrentDeviceEco
(
torch
.
cuda
.
current_device
()))
and
os
.
getenv
(
'AWQ_PAD'
)
==
None
):
os
.
environ
[
'AWQ_PAD'
]
=
'0'
else
:
os
.
environ
[
'AWQ_PAD'
]
=
'1'
except
Exception
as
e
:
print
(
"Info: this version torch cannot get eco device info.
\n
"
)
if
os
.
getenv
(
'AWQ_PAD'
)
!=
'0'
:
os
.
environ
[
'AWQ_PAD'
]
=
'1'
else
:
os
.
environ
[
'AWQ_PAD'
]
=
'0'
else
:
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
os
.
environ
[
'AWQ_PAD'
]
=
'0'
# Special handling for quantized Mixtral.
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
# FIXME(woosuk): This is a temporary hack.
...
...
vllm/model_executor/models/llama.py
View file @
cc4b902f
...
@@ -457,6 +457,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -457,6 +457,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -633,7 +634,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -633,7 +634,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
if
dim_k
%
4096
==
0
and
self
.
use_awq_pad
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
...
...
vllm/model_executor/models/qwen.py
View file @
cc4b902f
...
@@ -903,6 +903,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
...
@@ -903,6 +903,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def
_get_image_input_type
(
def
_get_image_input_type
(
self
,
self
,
...
@@ -1085,7 +1086,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
...
@@ -1085,7 +1086,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
if
dim_k
%
4096
==
0
and
self
.
use_awq_pad
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
...
...
vllm/model_executor/models/qwen2.py
View file @
cc4b902f
...
@@ -378,6 +378,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -378,6 +378,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
def
forward
(
def
forward
(
self
,
self
,
...
@@ -537,7 +538,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -537,7 +538,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
zeros_and_scalse
.
data
=
zeros_and_scalse
.
reshape
(
dim_n
,
-
1
)
#[k/greop_size,n]------>[n,k/group_size]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
qweight
.
data
=
qweight
.
data
.
reshape
(
dim_n
,
-
1
)
#[k,n/8]---->[n,k/8]
if
dim_k
%
4096
==
0
:
if
dim_k
%
4096
==
0
and
self
.
use_awq_pad
:
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse_pad
=
torch
.
zeros
(
dim_n
,
pad_group
,
dtype
=
torch
.
int32
).
cuda
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment