norm / vllm

Commit 39161c98, authored Feb 09, 2023 by Woosuk Kwon

Add OPT

parent e7d9d9c0

Showing 2 changed files with 207 additions and 0 deletions
cacheflow/worker/models/__init__.py    +5    -0
cacheflow/worker/models/opt.py         +202  -0
cacheflow/worker/models/__init__.py  (new file, mode 100644)
from cacheflow.worker.models.opt import OPTForCausalLM


__all__ = [
    'OPTForCausalLM',
]
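With this __init__.py, the OPT implementation is re-exported at the package level, so other worker code does not need to know the module layout. A one-line sketch of the intended call site (the call site itself is an assumption; only the import path is established by this commit):

# Hypothetical call site elsewhere in the worker:
from cacheflow.worker.models import OPTForCausalLM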
cacheflow/worker/models/opt.py  (new file, mode 100644)
"""1D OPT model compatible with HuggingFace weights."""
import
torch
from
torch
import
nn
from
transformers
import
OPTConfig
from
transformers
import
PreTrainedModel
class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.LongTensor):
        return super().forward(positions + self.offset)
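The offset-by-2 hack described in the comment above means position id p reads row p + 2 of the embedding table. A minimal standalone sketch (table size and dimensions are arbitrary, chosen only for illustration):

import torch
from cacheflow.worker.models.opt import OPTLearnedPositionalEmbedding

# Position id 0 reads row 2, id 1 reads row 3, and so on.
emb = OPTLearnedPositionalEmbedding(num_embeddings=2048, embedding_dim=16)
out = emb(torch.tensor([0, 1, 2]))      # same rows as emb.weight[[2, 3, 4]]
assert out.shape == (3, 16)
assert emb.weight.shape == (2048 + 2, 16)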
class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        q = self.q_proj(hidden_states) * self.scaling
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # TODO
        attn_output = None
        output = self.out_proj(attn_output)
        return output
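OPTAttention.forward stops at a TODO: attn_output is still None, so the projections are wired up but the attention itself is not computed in this commit. For orientation only, here is a minimal sketch of a plain causal scaled-dot-product attention over the q, k, v produced above; this is an assumption for illustration, not the cache-aware attention this worker code is presumably building toward:

import torch

def naive_causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           num_heads: int) -> torch.Tensor:
    # q, k, v: [seq_len, embed_dim]; q is assumed to be pre-scaled, as in
    # OPTAttention.forward above (q_proj output multiplied by self.scaling).
    seq_len, embed_dim = q.shape
    head_dim = embed_dim // num_heads
    # Split heads: [seq_len, embed_dim] -> [num_heads, seq_len, head_dim].
    q, k, v = (x.view(seq_len, num_heads, head_dim).transpose(0, 1)
               for x in (q, k, v))
    scores = torch.matmul(q, k.transpose(-2, -1))          # [heads, seq, seq]
    # Causal mask: each token attends only to itself and earlier positions.
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool),
                        diagonal=1)
    scores = scores.masked_fill(causal, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    out = torch.matmul(attn, v)                            # [heads, seq, head_dim]
    # Merge heads back: [seq_len, embed_dim].
    return out.transpose(0, 1).reshape(seq_len, embed_dim)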
class OPTDecoderLayer(nn.Module):

    def __init__(self, config: OPTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            bias=config.enable_bias,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        assert config.activation_function == 'relu'
        self.activation_fn = nn.ReLU()

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim,
                             bias=config.enable_bias)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim,
                             bias=config.enable_bias)
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim,
            elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states
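The paired comments above ("125m, 1.7B, ..., 175B applies layer norm BEFORE", "350m applies layer norm AFTER") are driven by a single flag, config.do_layer_norm_before, which comes straight from the HuggingFace OPTConfig. A quick sketch to inspect the published values (requires downloading the configs; the expected output simply restates those comments):

from transformers import OPTConfig

for name in ("facebook/opt-125m", "facebook/opt-350m"):
    cfg = OPTConfig.from_pretrained(name)
    print(name, cfg.do_layer_norm_before)
# Expected: facebook/opt-125m True (pre-norm), facebook/opt-350m False (post-norm).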
class OPTPreTrainedModel(PreTrainedModel):
    config_class = OPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OPTDecoderLayer"]
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module) -> None:
        del module  # unused
        return
class OPTDecoder(OPTPreTrainedModel):

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size,
                                         config.word_embed_proj_dim,
                                         self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size)

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size,
                                         config.word_embed_proj_dim,
                                         bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim,
                                        config.hidden_size, bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine)
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList(
            [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])

        # Initialize weights and apply final processing
        self.post_init()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        pos_embeds = self.embed_positions(positions)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)
        hidden_states = inputs_embeds + pos_embeds

        for layer in self.layers:
            hidden_states = layer(hidden_states)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)
        return hidden_states
class OPTModel(OPTPreTrainedModel):

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
    ) -> torch.Tensor:
        return self.decoder(input_ids, positions)
class OPTForCausalLM(OPTPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size,
                                 bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.LongTensor,
    ) -> torch.Tensor:
        hidden_states = self.model.decoder(input_ids, positions)
        logits = self.lm_head(hidden_states).contiguous()
        return logits
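Since the commit only adds the model definitions, a short construction sketch may help show how the classes fit together. The tiny config values below are illustrative, not a real OPT size, and the forward call is left commented out because OPTAttention.forward is still a TODO in this commit:

import torch
from transformers import OPTConfig
from cacheflow.worker.models import OPTForCausalLM

# A deliberately tiny config; field names follow HuggingFace's OPTConfig.
config = OPTConfig(
    vocab_size=32, hidden_size=16, word_embed_proj_dim=16,
    ffn_dim=64, num_hidden_layers=2, num_attention_heads=2,
    max_position_embeddings=128, activation_function='relu',
)
model = OPTForCausalLM(config)

# The 1D interface takes flat token ids plus explicit position ids;
# for a single fresh prompt the positions are simply 0..seq_len-1.
input_ids = torch.tensor([2, 5, 7, 11])
positions = torch.arange(input_ids.shape[0])
# logits = model(input_ids, positions)    # would be [seq_len, vocab_size]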