Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
4a151dd4
Unverified
Commit
4a151dd4
authored
May 25, 2023
by
Woosuk Kwon
Committed by
GitHub
May 25, 2023
Browse files
Add activation registry (#126)
parent
057daef7
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
22 additions
and
13 deletions
+22
-13
cacheflow/entrypoints/llm.py
cacheflow/entrypoints/llm.py
+1
-1
cacheflow/model_executor/layers/activation.py
cacheflow/model_executor/layers/activation.py
+15
-0
cacheflow/model_executor/models/gpt2.py
cacheflow/model_executor/models/gpt2.py
+2
-6
cacheflow/model_executor/models/gpt_neox.py
cacheflow/model_executor/models/gpt_neox.py
+2
-4
cacheflow/model_executor/models/opt.py
cacheflow/model_executor/models/opt.py
+2
-2
No files found.
cacheflow/entrypoints/llm.py
View file @
4a151dd4
...
...
@@ -61,7 +61,7 @@ class LLM:
while
self
.
llm_server
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_server
.
step
()
for
output
in
step_outputs
:
if
output
.
done
:
if
output
.
finished
()
:
outputs
.
append
(
output
)
if
use_tqdm
:
pbar
.
update
(
1
)
...
...
cacheflow/model_executor/layers/activation.py
View file @
4a151dd4
...
...
@@ -4,6 +4,21 @@ import torch.nn as nn
from
cacheflow
import
activation_ops
_ACTIVATION_REGISTRY
=
{
"gelu"
:
nn
.
GELU
(),
"gelu_new"
:
nn
.
GELU
(
approximate
=
"tanh"
),
# NOTE: This may introduce small rounding errors.
"gelu_fast"
:
nn
.
GELU
(
approximate
=
"tanh"
),
# NOTE: This may introduce small rounding errors.
"relu"
:
nn
.
ReLU
(),
}
def
get_act_fn
(
act_fn
:
str
)
->
nn
.
Module
:
"""Get an activation function by name."""
act_fn
=
act_fn
.
lower
()
if
act_fn
in
_ACTIVATION_REGISTRY
:
return
_ACTIVATION_REGISTRY
[
act_fn
]
raise
ValueError
(
f
"Activation function
{
act_fn
!
r
}
is not supported."
)
class
SiluAndMul
(
nn
.
Module
):
"""An activation function for SwiGLU.
...
...
cacheflow/model_executor/models/gpt2.py
View file @
4a151dd4
...
...
@@ -27,6 +27,7 @@ from torch import nn
from
transformers
import
GPT2Config
from
cacheflow.model_executor.input_metadata
import
InputMetadata
from
cacheflow.model_executor.layers.activation
import
get_act_fn
from
cacheflow.model_executor.layers.attention
import
GPTCacheFlowAttention
from
cacheflow.model_executor.layers.sampler
import
Sampler
from
cacheflow.model_executor.weight_utils
import
(
hf_model_weights_iterator
,
...
...
@@ -92,12 +93,7 @@ class GPT2MLP(nn.Module):
self
.
c_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
True
,
input_is_parallel
=
True
,
perform_initialization
=
False
)
act_fn
=
config
.
activation_function
if
act_fn
!=
"gelu_new"
:
raise
ValueError
(
f
"Unsupported activation:
{
act_fn
}
. "
"GPT-2 only supports gelu_new for now."
)
self
.
act
=
torch
.
nn
.
GELU
(
approximate
=
"tanh"
)
self
.
act
=
get_act_fn
(
config
.
activation_function
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
c_fc
(
hidden_states
)
...
...
cacheflow/model_executor/models/gpt_neox.py
View file @
4a151dd4
...
...
@@ -26,6 +26,7 @@ from torch import nn
from
transformers
import
GPTNeoXConfig
from
cacheflow.model_executor.input_metadata
import
InputMetadata
from
cacheflow.model_executor.layers.activation
import
get_act_fn
from
cacheflow.model_executor.layers.attention
import
GPTNeoXCacheFlowAttention
from
cacheflow.model_executor.layers.sampler
import
Sampler
from
cacheflow.model_executor.weight_utils
import
(
hf_model_weights_iterator
,
...
...
@@ -94,10 +95,7 @@ class GPTNeoXMLP(nn.Module):
self
.
dense_4h_to_h
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
input_is_parallel
=
True
,
perform_initialization
=
False
)
if
config
.
hidden_act
!=
'gelu'
:
raise
ValueError
(
f
'Unsupported activation:
{
config
.
hidden_act
}
. '
'Only gelu is supported for now.'
)
self
.
act
=
torch
.
nn
.
GELU
()
self
.
act
=
get_act_fn
(
config
.
hidden_act
)
def
forward
(
self
,
hidden_states
):
hidden_states
,
_
=
self
.
dense_h_to_4h
(
hidden_states
)
...
...
cacheflow/model_executor/models/opt.py
View file @
4a151dd4
...
...
@@ -26,6 +26,7 @@ from torch import nn
from
transformers
import
OPTConfig
from
cacheflow.model_executor.input_metadata
import
InputMetadata
from
cacheflow.model_executor.layers.activation
import
get_act_fn
from
cacheflow.model_executor.layers.attention
import
GPTCacheFlowAttention
from
cacheflow.model_executor.layers.sampler
import
Sampler
from
cacheflow.model_executor.weight_utils
import
(
hf_model_weights_iterator
,
...
...
@@ -105,8 +106,7 @@ class OPTDecoderLayer(nn.Module):
bias
=
config
.
enable_bias
,
)
self
.
do_layer_norm_before
=
config
.
do_layer_norm_before
assert
config
.
activation_function
==
'relu'
self
.
activation_fn
=
nn
.
ReLU
()
self
.
activation_fn
=
get_act_fn
(
config
.
activation_function
)
self
.
self_attn_layer_norm
=
nn
.
LayerNorm
(
self
.
embed_dim
,
elementwise_affine
=
config
.
layer_norm_elementwise_affine
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment