Commit 8d17774f (unverified)
Authored Nov 18, 2023 by Woosuk Kwon; committed by GitHub on Nov 18, 2023
Parent: e946260c

Add AWQ support for all models (#1714)
Showing 13 changed files with 90 additions and 17 deletions (+90 −17):

vllm/model_executor/layers/activation.py                +47  −5
vllm/model_executor/layers/quantization/awq.py           +3  −0
vllm/model_executor/layers/quantization/base_config.py   +8  −0
vllm/model_executor/layers/quantization/squeezellm.py    +3  −0
vllm/model_executor/models/bloom.py                      +3  −2
vllm/model_executor/models/falcon.py                     +4  −1
vllm/model_executor/models/gpt2.py                       +3  −1
vllm/model_executor/models/gpt_bigcode.py                +3  −1
vllm/model_executor/models/gpt_j.py                      +3  −1
vllm/model_executor/models/gpt_neox.py                   +3  −1
vllm/model_executor/models/mpt.py                        +2  −1
vllm/model_executor/models/opt.py                        +5  −3
vllm/model_executor/models/phi_1_5.py                    +3  −1
vllm/model_executor/layers/activation.py

 """Custom activation functions."""
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
 from vllm import activation_ops
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 
 class SiluAndMul(nn.Module):
...
@@ -39,6 +42,27 @@ class FastGELU(nn.Module):
         return out
 
 
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.scales = nn.Parameter(
+            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+
+    def forward(self, x: torch.Tensor):
+        return self.act(x) / self.scales
+
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_fast": FastGELU(),
...
@@ -48,9 +72,27 @@ _ACTIVATION_REGISTRY = {
 }
 
 
-def get_act_fn(act_fn: str) -> nn.Module:
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+) -> nn.Module:
     """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(
-        f"Activation function {act_fn!r} is not supported.")
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None:
+        if act_fn_name in quant_config.get_scaled_act_names():
+            if intermediate_size is None:
+                raise ValueError(
+                    "intermediate_size must be specified for scaled "
+                    "activation functions.")
+            return ScaledActivation(
+                act_fn,
+                intermediate_size,
+                params_dtype=torch.get_default_dtype(),
+            )
+    return act_fn
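
Note: the following standalone sketch (not part of the commit) mirrors the registry-plus-wrapper pattern above so the new get_act_fn behavior can be run in isolation. StubQuantConfig and the CPU-resident scales are illustrative stand-ins; the real ScaledActivation allocates its scales with torch.empty on "cuda" and fills them from the AWQ checkpoint.

import torch
import torch.nn as nn


class ScaledActivation(nn.Module):
    """Activation followed by element-wise division by learned scales."""

    def __init__(self, act_module: nn.Module, hidden_size: int):
        super().__init__()
        self.act = act_module
        # torch.ones on CPU keeps the sketch runnable anywhere; vLLM uses
        # torch.empty(..., device="cuda") loaded from the checkpoint.
        self.scales = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales


class StubQuantConfig:
    """Stand-in for AWQConfig: reports which activations carry scales."""

    def get_scaled_act_names(self):
        return ["gelu"]


_REGISTRY = {"gelu": nn.GELU()}


def get_act_fn(name, quant_config=None, intermediate_size=None):
    act_fn = _REGISTRY[name.lower()]
    if quant_config and name in quant_config.get_scaled_act_names():
        return ScaledActivation(act_fn, intermediate_size)
    return act_fn


act = get_act_fn("gelu", StubQuantConfig(), intermediate_size=8)
print(type(act).__name__)            # ScaledActivation
print(act(torch.randn(2, 8)).shape)  # torch.Size([2, 8])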
vllm/model_executor/layers/quantization/awq.py

...
@@ -63,6 +63,9 @@ class AWQConfig(QuantizationConfig):
     def get_linear_method(self) -> "AWQLinearMethod":
         return AWQLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+
 
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
...
vllm/model_executor/layers/quantization/base_config.py

...
@@ -54,3 +54,11 @@ class QuantizationConfig(ABC):
     def get_linear_method(self) -> LinearMethodBase:
         """Get the linear method to use for the quantized linear layer."""
         raise NotImplementedError
+
+    @abstractmethod
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
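
Note: with get_scaled_act_names declared @abstractmethod, every QuantizationConfig subclass must implement it. The sketch below (Int8DemoConfig is a hypothetical name, not a vLLM class) shows the minimal conforming implementation for a method with no post-scaled activations.

from abc import ABC, abstractmethod
from typing import List


class QuantizationConfig(ABC):
    """Trimmed-down stand-in for vLLM's base class."""

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Activation function names that should be post-scaled."""
        raise NotImplementedError


class Int8DemoConfig(QuantizationConfig):
    # A method without activation scales returns an empty list, exactly
    # as SqueezeLLMConfig does in the next file.
    def get_scaled_act_names(self) -> List[str]:
        return []


print(Int8DemoConfig().get_scaled_act_names())  # []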
vllm/model_executor/layers/quantization/squeezellm.py

...
@@ -52,6 +52,9 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_linear_method(self) -> "SqueezeLLMLinearMethod":
         return SqueezeLLMLinearMethod(self)
 
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
 
 class SqueezeLLMLinearMethod(LinearMethodBase):
     """Linear method for SqueezeLLM.
...
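
Note: the two configs now answer the same question differently, and get_act_fn branches on the returned list without knowing which quantization method is active. A small sketch of that dispatch, with stub classes standing in for AWQConfig and SqueezeLLMConfig:

from typing import List


class AWQStub:
    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class SqueezeLLMStub:
    def get_scaled_act_names(self) -> List[str]:
        return []


for cfg in (AWQStub(), SqueezeLLMStub()):
    # get_act_fn wraps the activation in ScaledActivation only when the
    # requested name appears in this list.
    wrap = "gelu" in cfg.get_scaled_act_names()
    print(type(cfg).__name__, "-> wrap gelu:", wrap)
# AWQStub -> wrap gelu: True
# SqueezeLLMStub -> wrap gelu: False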
vllm/model_executor/models/bloom.py

...
@@ -145,7 +145,8 @@ class BloomMLP(nn.Module):
             4 * hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
...
@@ -154,7 +155,7 @@ class BloomMLP(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)
-        x = self.act(x)
+        x = self.gelu_impl(x)
         x, _ = self.dense_4h_to_h(x)
         return x
...
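
Note: every model file in this commit fetches the quant config the same way. Because linear_method is None for unquantized models, the code uses getattr(linear_method, "quant_config", None) instead of plain attribute access. A runnable illustration (AWQLinearMethodStub is a made-up stand-in, not a vLLM class):

class AWQLinearMethodStub:
    def __init__(self, quant_config):
        self.quant_config = quant_config


for linear_method in (None, AWQLinearMethodStub(quant_config="awq-cfg")):
    # getattr with a default covers both the unquantized case (None has
    # no quant_config attribute) and real linear-method objects.
    quant_config = getattr(linear_method, "quant_config", None)
    print(quant_config)
# None
# awq-cfg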
vllm/model_executor/models/falcon.py

...
@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig
 
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
...
@@ -131,6 +132,7 @@ class FalconAttention(nn.Module):
             self.hidden_size,
             bias=config.bias,
             skip_bias_add=True,
+            linear_method=linear_method,
             reduce_results=self.reduce_row_parallel_results)
         self.use_rotary = config.rotary
...
@@ -206,7 +208,8 @@ class FalconMLP(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             linear_method=linear_method)
-        self.act = nn.GELU()
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(
...
vllm/model_executor/models/gpt2.py

...
@@ -118,7 +118,9 @@ class GPT2MLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
...
vllm/model_executor/models/gpt_bigcode.py

...
@@ -137,7 +137,9 @@ class GPTBigMLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
...
vllm/model_executor/models/gpt_j.py

...
@@ -128,7 +128,9 @@ class GPTJMLP(nn.Module):
             hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
...
vllm/model_executor/models/gpt_neox.py

...
@@ -124,7 +124,9 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.hidden_act)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              config.intermediate_size)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
...
vllm/model_executor/models/mpt.py

...
@@ -130,7 +130,8 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, intermediate_size)
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
...
vllm/model_executor/models/opt.py

...
@@ -129,7 +129,9 @@ class OPTDecoderLayer(nn.Module):
             linear_method=linear_method,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        self.activation_fn = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
...
@@ -251,7 +253,7 @@ class OPTDecoder(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
         if self.project_in is not None:
-            inputs_embeds = self.project_in(inputs_embeds)
+            inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
         for i in range(len(self.layers)):
...
@@ -266,7 +268,7 @@ class OPTDecoder(nn.Module):
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
         if self.project_out is not None:
-            hidden_states = self.project_out(hidden_states)
+            hidden_states, _ = self.project_out(hidden_states)
         return hidden_states
...
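
Note: the project_in/project_out fixes above are a different kind of change. vLLM's parallel and replicated linear layers return an (output, bias) tuple rather than a bare tensor, so the old hidden_states = self.project_out(hidden_states) bound a tuple to hidden_states. A minimal reproduction of the failure mode, assuming the same (output, bias) convention (TupleLinear is illustrative, not a vLLM class):

import torch
import torch.nn as nn


class TupleLinear(nn.Module):
    """Mimics vLLM's linear layers, which return (output, bias)."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return x @ self.weight.t(), None  # bias handled by the caller


project_out = TupleLinear(4, 4)
x = torch.randn(2, 4)

wrong = project_out(x)     # old code: binds the whole (output, bias) tuple
fixed, _ = project_out(x)  # fixed code: unpacks the output tensor
print(type(wrong).__name__, fixed.shape)  # tuple torch.Size([2, 4])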
vllm/model_executor/models/phi_1_5.py

...
@@ -168,7 +168,9 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              n_inner)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
...