Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c30139a0
Commit
c30139a0
authored
Apr 30, 2019
by
thomwolf
Browse files
add special tokens to gpt-2
parent
b832d5bb
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
62 additions
and
14 deletions
+62
-14
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+52
-8
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+3
-3
tests/modeling_gpt2_test.py
tests/modeling_gpt2_test.py
+7
-3
No files found.
pytorch_pretrained_bert/modeling_gpt2.py
View file @
c30139a0
...
...
@@ -107,6 +107,7 @@ class GPT2Config(object):
def
__init__
(
self
,
vocab_size_or_config_json_file
=
50257
,
n_special
=
0
,
n_positions
=
1024
,
n_ctx
=
1024
,
n_embd
=
768
,
...
...
@@ -119,6 +120,7 @@ class GPT2Config(object):
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...)
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states.
...
...
@@ -137,6 +139,7 @@ class GPT2Config(object):
self
.
__dict__
[
key
]
=
value
elif
isinstance
(
vocab_size_or_config_json_file
,
int
):
self
.
vocab_size
=
vocab_size_or_config_json_file
self
.
n_special
=
n_special
self
.
n_ctx
=
n_ctx
self
.
n_positions
=
n_positions
self
.
n_embd
=
n_embd
...
...
@@ -150,6 +153,10 @@ class GPT2Config(object):
"or the path to a pretrained model config file (str)"
)
@
property
def
total_tokens_embeddings
(
self
):
return
self
.
vocab_size
+
self
.
n_special
@
classmethod
def
from_dict
(
cls
,
json_object
):
"""Constructs a `GPT2Config` from a Python dictionary of parameters."""
...
...
@@ -290,11 +297,12 @@ class GPT2LMHead(nn.Module):
def
__init__
(
self
,
model_embeddings_weights
,
config
):
super
(
GPT2LMHead
,
self
).
__init__
()
self
.
n_embd
=
config
.
n_embd
embed_shape
=
model_embeddings_weights
.
shape
self
.
decoder
=
nn
.
Linear
(
embed_shape
[
1
],
embed_shape
[
0
],
bias
=
False
)
self
.
set_embeddings_weights
(
model_embeddings_weights
)
def
set_embeddings_weights
(
self
,
model_embeddings_weights
):
embed_shape
=
model_embeddings_weights
.
shape
self
.
decoder
=
nn
.
Linear
(
embed_shape
[
1
],
embed_shape
[
0
],
bias
=
False
)
self
.
decoder
.
weight
=
model_embeddings_weights
# Tied weights
def
forward
(
self
,
hidden_state
):
...
...
@@ -345,7 +353,7 @@ class GPT2PreTrainedModel(nn.Module):
)
self
.
config
=
config
def
set_
tied
(
self
):
def
set_
num_special_tokens
(
self
,
num_special_tokens
):
pass
def
init_weights
(
self
,
module
):
...
...
@@ -475,14 +483,32 @@ class GPT2PreTrainedModel(nn.Module):
"Error(s) in loading state_dict for {}:
\n\t
{}"
.
format
(
model
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
))
)
# Make sure we are still sharing the output and input embeddings after loading weights
model
.
set_tied
()
# Add additional embeddings for special tokens if needed
# This step also makes sure we are still sharing the output and input embeddings after loading weights
model
.
set_num_special_tokens
(
num_special_tokens
if
num_special_tokens
is
not
None
else
config
.
n_special
)
return
model
class
GPT2Model
(
GPT2PreTrainedModel
):
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
GPT-2 uses a single embedding matrix to store the word and special embeddings.
Special token embeddings are embeddings for additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
The embeddings are ordered as follows in the token embedding matrix:
[0, ----------------------
... -> word embeddings
config.vocab_size - 1, ______________________
config.vocab_size,
... -> special embeddings
config.vocab_size + config.n_special - 1] ______________________
where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
total_tokens_embeddings = config.vocab_size + config.n_special
You should use the associated indices to index the embeddings.
Params:
config: a GPT2Config class instance with the configuration to build a new model
...
...
@@ -529,6 +555,20 @@ class GPT2Model(GPT2PreTrainedModel):
self
.
apply
(
self
.
init_weights
)
def
set_num_special_tokens
(
self
,
num_special_tokens
):
" Update input embeddings with new embedding matrice if needed "
if
self
.
config
.
n_special
==
num_special_tokens
:
return
# Update config
self
.
config
.
n_special
=
num_special_tokens
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
old_embed
=
self
.
wte
self
.
wte
=
nn
.
Embedding
(
self
.
config
.
total_tokens_embeddings
,
self
.
config
.
n_embd
)
self
.
wte
.
to
(
old_embed
.
weight
.
device
)
self
.
init_weights
(
self
.
wte
)
# Copy word embeddings from the previous weights
self
.
wte
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
=
old_embed
.
weight
.
data
[:
self
.
config
.
vocab_size
,
:]
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
past
=
None
):
if
past
is
None
:
past_length
=
0
...
...
@@ -610,9 +650,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self
.
lm_head
=
GPT2LMHead
(
self
.
transformer
.
wte
.
weight
,
config
)
self
.
apply
(
self
.
init_weights
)
def
set_tied
(
self
):
""" Make sure we are sharing the embeddings
def
set_num_special_tokens
(
self
,
num_special_tokens
):
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
wte
.
weight
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
lm_labels
=
None
,
past
=
None
):
...
...
@@ -687,9 +729,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self
.
multiple_choice_head
=
GPT2MultipleChoiceHead
(
config
)
self
.
apply
(
self
.
init_weights
)
def
set_tied
(
self
):
""" Make sure we are sharing the embeddings
def
set_num_special_tokens
(
self
,
num_special_tokens
):
""" Update input and output embeddings with new embedding matrice
Make sure we are sharing the embeddings
"""
self
.
transformer
.
set_num_special_tokens
(
num_special_tokens
)
self
.
lm_head
.
set_embeddings_weights
(
self
.
transformer
.
wte
.
weight
)
def
forward
(
self
,
input_ids
,
mc_token_ids
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
past
=
None
):
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
c30139a0
...
...
@@ -344,11 +344,12 @@ class OpenAIGPTLMHead(nn.Module):
def
__init__
(
self
,
model_embeddings_weights
,
config
):
super
(
OpenAIGPTLMHead
,
self
).
__init__
()
self
.
n_embd
=
config
.
n_embd
embed_shape
=
model_embeddings_weights
.
shape
self
.
decoder
=
nn
.
Linear
(
embed_shape
[
1
],
embed_shape
[
0
],
bias
=
False
)
self
.
set_embeddings_weights
(
model_embeddings_weights
)
def
set_embeddings_weights
(
self
,
model_embeddings_weights
):
embed_shape
=
model_embeddings_weights
.
shape
self
.
decoder
=
nn
.
Linear
(
embed_shape
[
1
],
embed_shape
[
0
],
bias
=
False
)
self
.
decoder
.
weight
=
model_embeddings_weights
# Tied weights
def
forward
(
self
,
hidden_state
):
...
...
@@ -592,8 +593,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
def
__init__
(
self
,
config
):
super
(
OpenAIGPTModel
,
self
).
__init__
(
config
)
num_tokens
=
config
.
vocab_size
+
config
.
n_special
self
.
tokens_embed
=
nn
.
Embedding
(
num_tokens
,
config
.
n_embd
)
self
.
tokens_embed
=
nn
.
Embedding
(
config
.
total_tokens_embeddings
,
config
.
n_embd
)
self
.
positions_embed
=
nn
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
)
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
block
=
Block
(
config
.
n_ctx
,
config
,
scale
=
True
)
...
...
tests/modeling_gpt2_test.py
View file @
c30139a0
...
...
@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
use_token_type_ids
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
n_special
=
1
,
n_positions
=
33
,
n_embd
=
32
,
n_layer
=
5
,
...
...
@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
self
.
use_token_type_ids
=
use_token_type_ids
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
n_special
=
n_special
self
.
n_positions
=
n_positions
self
.
n_embd
=
n_embd
self
.
n_layer
=
n_layer
...
...
@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
self
.
scope
=
scope
def
prepare_config_and_inputs
(
self
):
input_ids
=
GPT2ModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
self
.
vocab_size
)
total_num_tokens
=
self
.
vocab_size
+
self
.
n_special
input_ids
=
GPT2ModelTest
.
ids_tensor
([
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
],
total_num_tokens
)
position_ids
=
None
if
self
.
use_position_ids
:
...
...
@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
config
=
GPT2Config
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
n_special
=
self
.
n_special
,
n_positions
=
self
.
n_positions
,
n_embd
=
self
.
n_embd
,
n_layer
=
self
.
n_layer
,
...
...
@@ -130,7 +134,7 @@ class GPT2ModelTest(unittest.TestCase):
return
outputs
def
check_gpt2_lm_head_output
(
self
,
result
):
total_voc
=
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
...
...
@@ -157,7 +161,7 @@ class GPT2ModelTest(unittest.TestCase):
return
outputs
def
check_gpt2_double_heads_output
(
self
,
result
):
total_voc
=
self
.
vocab_size
total_voc
=
self
.
n_special
+
self
.
vocab_size
self
.
parent
.
assertListEqual
(
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
n_choices
,
self
.
seq_length
,
total_voc
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment