Unverified commit cf35d8f3, authored Nov 20, 2023 by Woosuk Kwon, committed by GitHub on Nov 20, 2023
[BugFix] Fix TP support for AWQ (#1731)
Parent: 4bb6b671
Showing 2 changed files with 38 additions and 14 deletions (+38 -14):

vllm/model_executor/layers/activation.py  (+35 -11)
vllm/model_executor/models/opt.py  (+3 -3)
vllm/model_executor/layers/activation.py
@@ -6,6 +6,10 @@ import torch.nn as nn
 
 from vllm import activation_ops
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.parallel_utils.utils import divide
+from vllm.model_executor.utils import set_weight_attrs
 
 
 class SiluAndMul(nn.Module):
@@ -51,17 +55,38 @@ class ScaledActivation(nn.Module):
     def __init__(
         self,
         act_module: nn.Module,
-        hidden_size: int,
-        params_dtype: torch.dtype,
+        intermediate_size: int,
+        input_is_parallel: bool = True,
+        params_dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.act = act_module
+        if input_is_parallel:
+            tp_size = get_tensor_model_parallel_world_size()
+            intermediate_size_per_partition = divide(intermediate_size,
+                                                     tp_size)
+        else:
+            intermediate_size_per_partition = intermediate_size
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
         self.scales = nn.Parameter(
-            torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
+            torch.empty(intermediate_size_per_partition,
+                        dtype=params_dtype,
+                        device="cuda"))
+        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.act(x) / self.scales
 
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = param_data.shape[0]
+        start_idx = tp_rank * shard_size
+        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
 
 _ACTIVATION_REGISTRY = {
     "gelu": nn.GELU(),
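For readers skimming the hunk above: the new weight_loader shards the per-channel activation scales across tensor-parallel ranks, with each rank taking a contiguous slice via narrow() along dimension 0. A minimal standalone sketch of that slicing, using made-up sizes (intermediate_size = 8, tp_size = 2) rather than anything from the diff:

import torch

def shard_scales(full_scales: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Each rank keeps a contiguous slice of the per-channel scales,
    # mirroring the narrow()-based logic in weight_loader above.
    shard_size = full_scales.shape[0] // tp_size   # plays the role of divide(intermediate_size, tp_size)
    start_idx = tp_rank * shard_size
    return full_scales.narrow(0, start_idx, shard_size).clone()

full = torch.arange(8, dtype=torch.float32)        # pretend intermediate_size = 8
print(shard_scales(full, tp_rank=0, tp_size=2))    # tensor([0., 1., 2., 3.])
print(shard_scales(full, tp_rank=1, tp_size=2))    # tensor([4., 5., 6., 7.])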
@@ -76,6 +101,8 @@ def get_act_fn(
     act_fn_name: str,
     quant_config: Optional[QuantizationConfig] = None,
     intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
 ) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
@@ -84,14 +111,11 @@ def get_act_fn(
             f"Activation function {act_fn_name!r} is not supported.")
     act_fn = _ACTIVATION_REGISTRY[act_fn_name]
-    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names(
-    ):
+    if (quant_config is not None
+            and act_fn_name in quant_config.get_scaled_act_names()):
         if intermediate_size is None:
             raise ValueError("intermediate_size must be specified for scaled "
                              "activation functions.")
-        return ScaledActivation(
-            act_fn,
-            intermediate_size,
-            params_dtype=torch.get_default_dtype(),
-        )
+        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
+                                params_dtype)
     return act_fn
vllm/model_executor/models/opt.py
@@ -129,9 +129,6 @@ class OPTDecoderLayer(nn.Module):
             linear_method=linear_method,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        quant_config = getattr(linear_method, "quant_config", None)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config, config.ffn_dim)
 
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
@@ -142,6 +139,9 @@ class OPTDecoderLayer(nn.Module):
             bias=config.enable_bias,
             linear_method=linear_method,
         )
+        quant_config = getattr(linear_method, "quant_config", None)
+        self.activation_fn = get_act_fn(config.activation_function,
+                                        quant_config, config.ffn_dim)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
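The moved call now constructs the activation between fc1 and fc2. A hedged, single-process sketch of how the sharded MLP path lines up under tensor parallelism; the column-parallel split of fc1 and all numbers below are assumptions for illustration and are not shown in this diff:

import torch
import torch.nn as nn
import torch.nn.functional as F

tp_size, embed_dim, ffn_dim = 2, 8, 32
shard = ffn_dim // tp_size                    # per-rank slice of the intermediate dim

x = torch.randn(4, embed_dim)
outs = []
for rank in range(tp_size):
    fc1_shard = nn.Linear(embed_dim, shard)              # assumed column-parallel slice of fc1
    scales = torch.ones(shard)                           # per-rank slice of the AWQ activation scales
    fc2_shard = nn.Linear(shard, embed_dim, bias=False)  # row-parallel slice of fc2
    h = F.gelu(fc1_shard(x)) / scales                    # what ScaledActivation computes on each rank
    outs.append(fc2_shard(h))
out = torch.stack(outs).sum(dim=0)                       # stands in for RowParallelLinear's all-reduce
print(out.shape)                                         # torch.Size([4, 8])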