norm/vllm · Commits · 8d17774f

Commit 8d17774f (unverified), authored Nov 18, 2023 by Woosuk Kwon, committed via GitHub on Nov 18, 2023.

Add AWQ support for all models (#1714)

Parent commit: e946260c
Changes: 13 changed files with 90 additions and 17 deletions (+90 / -17).
Changed files:
  vllm/model_executor/layers/activation.py                 +47 / -5
  vllm/model_executor/layers/quantization/awq.py            +3 / -0
  vllm/model_executor/layers/quantization/base_config.py    +8 / -0
  vllm/model_executor/layers/quantization/squeezellm.py     +3 / -0
  vllm/model_executor/models/bloom.py                        +3 / -2
  vllm/model_executor/models/falcon.py                       +4 / -1
  vllm/model_executor/models/gpt2.py                         +3 / -1
  vllm/model_executor/models/gpt_bigcode.py                  +3 / -1
  vllm/model_executor/models/gpt_j.py                        +3 / -1
  vllm/model_executor/models/gpt_neox.py                     +3 / -1
  vllm/model_executor/models/mpt.py                          +2 / -1
  vllm/model_executor/models/opt.py                          +5 / -3
  vllm/model_executor/models/phi_1_5.py                      +3 / -1
vllm/model_executor/layers/activation.py (+47 / -5)

 """Custom activation functions."""
+from typing import Optional
+
 import torch
 import torch.nn as nn

 from vllm import activation_ops
+from vllm.model_executor.layers.quantization import QuantizationConfig


 class SiluAndMul(nn.Module):
...
@@ -39,6 +42,27 @@ class FastGELU(nn.Module):
         return out


+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        hidden_size: int,
+        params_dtype: torch.dtype,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.scales = nn.Parameter(
+            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+
+    def forward(self, x: torch.Tensor):
+        return self.act(x) / self.scales
+
+
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
     "gelu_fast": FastGELU(),
...
@@ -48,9 +72,27 @@ _ACTIVATION_REGISTRY = {
 }


-def get_act_fn(act_fn: str) -> nn.Module:
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+) -> nn.Module:
     """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(
-        f"Activation function {act_fn!r} is not supported.")
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(
+            f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None:
+        if act_fn_name in quant_config.get_scaled_act_names():
+            if intermediate_size is None:
+                raise ValueError(
+                    "intermediate_size must be specified for scaled "
+                    "activation functions.")
+            return ScaledActivation(
+                act_fn,
+                intermediate_size,
+                params_dtype=torch.get_default_dtype(),
+            )
+    return act_fn
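
A quick way to see what the new ScaledActivation does: it keeps the activation's output shape and simply divides each hidden channel by a per-channel scale. Below is a minimal standalone sketch of that post-scale pattern (the class name, CPU tensors, and ones-initialized scales are illustrative assumptions; in vLLM the scales are a CUDA parameter loaded from the quantized checkpoint):

# Standalone sketch of the post-scale pattern used by ScaledActivation above.
import torch
import torch.nn as nn


class PostScaledAct(nn.Module):
    def __init__(self, act: nn.Module, hidden_size: int):
        super().__init__()
        self.act = act
        # vLLM loads these from the quantized checkpoint; ones() keeps this
        # sketch a no-op so it runs anywhere.
        self.scales = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same shape in and out; each hidden channel is divided by its scale.
        return self.act(x) / self.scales


act = PostScaledAct(nn.GELU(), hidden_size=16)
print(act(torch.randn(2, 16)).shape)  # torch.Size([2, 16])

With all-ones scales this reduces to the plain activation; an AWQ checkpoint is expected to supply non-trivial per-channel scales for the activations named by get_scaled_act_names().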
vllm/model_executor/layers/quantization/awq.py (+3 / -0)

...
@@ -63,6 +63,9 @@ class AWQConfig(QuantizationConfig):
     def get_linear_method(self) -> "AWQLinearMethod":
         return AWQLinearMethod(self)

+    def get_scaled_act_names(self) -> List[str]:
+        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+

 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
...
vllm/model_executor/layers/quantization/base_config.py (+8 / -0)

...
@@ -54,3 +54,11 @@ class QuantizationConfig(ABC):
     def get_linear_method(self) -> LinearMethodBase:
         """Get the linear method to use for the quantized linear layer."""
         raise NotImplementedError
+
+    @abstractmethod
+    def get_scaled_act_names(self) -> List[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
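
Because get_scaled_act_names() is declared @abstractmethod on QuantizationConfig, every quantization backend now has to implement it, even if only to return an empty list (as SqueezeLLM does in the next file). A sketch of what a hypothetical future backend would add (the class name is made up and the other abstract methods are elided for brevity):

# Hypothetical subclass; only the new method is shown, the rest is elided.
from typing import List

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class MyQuantConfig(QuantizationConfig):
    ...  # get_name, get_linear_method, etc. omitted

    def get_scaled_act_names(self) -> List[str]:
        # Names of activations that should be wrapped in ScaledActivation;
        # return [] if the method does not post-scale activations.
        return []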
vllm/model_executor/layers/quantization/squeezellm.py (+3 / -0)

...
@@ -52,6 +52,9 @@ class SqueezeLLMConfig(QuantizationConfig):
     def get_linear_method(self) -> "SqueezeLLMLinearMethod":
         return SqueezeLLMLinearMethod(self)

+    def get_scaled_act_names(self) -> List[str]:
+        return []
+

 class SqueezeLLMLinearMethod(LinearMethodBase):
     """Linear method for SqueezeLLM.
...
vllm/model_executor/models/bloom.py (+3 / -2)

...
@@ -145,7 +145,8 @@ class BloomMLP(nn.Module):
             4 * hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
...
@@ -154,7 +155,7 @@ class BloomMLP(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)
-        x = self.act(x)
+        x = self.gelu_impl(x)
         x, _ = self.dense_4h_to_h(x)
         return x
...
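
The BloomMLP change above is the template for every model file that follows: read the optional quant_config off the linear method (it is None for unquantized models, hence the defensive getattr) and pass it to get_act_fn together with the activation's channel count so a ScaledActivation of the right size can be built when needed. A minimal sketch of that recurring pattern (the helper name is illustrative; it assumes a vLLM installation for the get_act_fn import):

# Sketch of the pattern repeated in the model diffs below (illustrative helper).
import torch.nn as nn

from vllm.model_executor.layers.activation import get_act_fn


def make_act(act_name: str, linear_method, intermediate_size: int) -> nn.Module:
    # linear_method is None when no quantization is configured, so the
    # attribute lookup has to be defensive.
    quant_config = getattr(linear_method, "quant_config", None)
    # With an AWQ config and a gelu-family activation this returns a
    # ScaledActivation sized to intermediate_size; otherwise it returns the
    # plain registry module.
    return get_act_fn(act_name, quant_config, intermediate_size)

The gpt2, gpt_bigcode, gpt_j, gpt_neox, mpt, opt, and phi_1_5 diffs below apply exactly this pattern, differing only in which config attribute names the activation and which size (intermediate_size, config.intermediate_size, config.ffn_dim, n_inner, or 4 * hidden_size) is passed.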
vllm/model_executor/models/falcon.py (+4 / -1)

...
@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig

 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
...
@@ -131,6 +132,7 @@ class FalconAttention(nn.Module):
             self.hidden_size,
             bias=config.bias,
             skip_bias_add=True,
+            linear_method=linear_method,
             reduce_results=self.reduce_row_parallel_results)
         self.use_rotary = config.rotary
...
@@ -206,7 +208,8 @@ class FalconMLP(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             linear_method=linear_method)
-        self.act = nn.GELU()
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(
...
vllm/model_executor/models/gpt2.py (+3 / -1)

...
@@ -118,7 +118,9 @@ class GPT2MLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
...
vllm/model_executor/models/gpt_bigcode.py (+3 / -1)

...
@@ -137,7 +137,9 @@ class GPTBigMLP(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
...
vllm/model_executor/models/gpt_j.py (+3 / -1)

...
@@ -128,7 +128,9 @@ class GPTJMLP(nn.Module):
             hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              intermediate_size)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
...
vllm/model_executor/models/gpt_neox.py (+3 / -1)

...
@@ -124,7 +124,9 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.hidden_act)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.hidden_act, quant_config,
+                              config.intermediate_size)

     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
...
vllm/model_executor/models/mpt.py (+2 / -1)

...
@@ -130,7 +130,8 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             linear_method=linear_method,
         )
-        self.act = get_act_fn("gelu")
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn("gelu", quant_config, intermediate_size)
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
...
vllm/model_executor/models/opt.py (+5 / -3)

...
@@ -129,7 +129,9 @@ class OPTDecoderLayer(nn.Module):
             linear_method=linear_method,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        self.activation_fn = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
...
@@ -251,7 +253,7 @@ class OPTDecoder(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
         if self.project_in is not None:
-            inputs_embeds = self.project_in(inputs_embeds)
+            inputs_embeds, _ = self.project_in(inputs_embeds)
         hidden_states = inputs_embeds + pos_embeds
         for i in range(len(self.layers)):
...
@@ -266,7 +268,7 @@ class OPTDecoder(nn.Module):
         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
         if self.project_out is not None:
-            hidden_states = self.project_out(hidden_states)
+            hidden_states, _ = self.project_out(hidden_states)
         return hidden_states
...
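
Beyond threading quant_config into get_act_fn, the OPT diff also fixes the project_in / project_out calls: like the other parallel linear layers touched in this commit (see the x, _ = ... calls in BloomMLP above), these layers return an (output, bias) tuple rather than a bare tensor, so the result has to be unpacked. A toy illustration of why the old assignment was wrong (no vLLM dependency; the class is a stand-in, not the real layer):

# Toy stand-in for a layer that returns (output, bias), mimicking the
# skip_bias_add-style parallel linear layers in this diff.
import torch
import torch.nn as nn


class TupleLinear(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_out, d_in) * 0.02)

    def forward(self, x: torch.Tensor):
        # Returns the output together with a (here unused) bias term.
        return x @ self.weight.t(), None


proj = TupleLinear(8, 8)
x = torch.randn(2, 8)

out_old = proj(x)     # before the fix: a (tensor, None) tuple, which would
                      # break the later `inputs_embeds + pos_embeds` addition
out_new, _ = proj(x)  # after the fix: unpack the tensor, discard the bias
print(type(out_old), out_new.shape)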
vllm/model_executor/models/phi_1_5.py (+3 / -1)

...
@@ -168,7 +168,9 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-        self.act = get_act_fn(config.activation_function)
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.act = get_act_fn(config.activation_function, quant_config,
+                              n_inner)

     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
...