Commit 3a848111 authored by thomwolf

update config, docstrings and readme to switch to separated tokens and position embeddings

parent 98c96fb1
README.md
@@ -391,35 +391,36 @@ An example on how to use this class is given in the [`run_squad.py`](./examples/
`OpenAIGPTModel` is the basic OpenAI GPT Transformer model with a layer of summed token and position embeddings followed by a series of 12 identical self-attention blocks.
OpenAI GPT uses a single embedding matrix to store the word and special token embeddings.
Special token embeddings are embeddings for additional tokens that are not pre-trained: `[SEP]`, `[CLS]`...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
The embeddings are ordered as follows in the token embeddings matrix:
```python
[0,                                            ----------------------
 ...                                            -> word embeddings
 config.vocab_size - 1,                         ______________________
 config.vocab_size,
 ...                                            -> special embeddings
 config.vocab_size + config.n_special - 1]      ______________________
```
where `total_tokens_embeddings` can be obtained as `config.total_tokens_embeddings` and is:
`total_tokens_embeddings = config.vocab_size + config.n_special`
You should use the associated indices to index the embeddings.
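For concreteness, here is a small sketch of these index ranges (the sizes are illustrative; 40478 is the BPE vocabulary size of the released model):

```python
# Illustrative sizes only.
vocab_size = 40478
n_special = 2   # e.g. a start token and a delimiter token added for fine-tuning

word_token_indices = range(0, vocab_size)                          # pre-trained word (BPE) embeddings
special_token_indices = range(vocab_size, vocab_size + n_special)  # newly initialized special tokens

total_tokens_embeddings = vocab_size + n_special   # == config.total_tokens_embeddings
```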
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
We detail them here. This model takes as *inputs*:
[`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py)
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length] where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
- `position_ids`: an optional torch.LongTensor with the same shape as input_ids with the position indices (selected in the range [0, config.n_positions - 1[).
- `token_type_ids`: an optional torch.LongTensor with the same shape as input_ids. You can use it to add a third type of embedding to each input token in the sequence (the previous two being the word and position embeddings). The input, position and token_type embeddings are summed inside the Transformer before the first self-attention block.
This model *outputs*:
- `hidden_states`: the encoded-hidden-states at the top of the model as a torch.FloatTensor of size [batch_size, sequence_length, hidden_size] (or more generally [d_1, ..., d_n, hidden_size] were d_1 ... d_n are the dimension of input_ids)
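A minimal usage sketch for the inputs and outputs above (the `'openai-gpt'` shortcut name and the token indices are placeholders/assumptions):

```python
import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTModel

model = OpenAIGPTModel.from_pretrained('openai-gpt')
model.eval()

# Placeholder BPE indices; in practice they come from the OpenAI GPT tokenizer.
input_ids = torch.LongTensor([[250, 971, 34, 7]])   # [batch_size=1, sequence_length=4]

with torch.no_grad():
    hidden_states = model(input_ids)                 # position_ids and token_type_ids are optional
# hidden_states has shape [1, 4, config.n_embd]
```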
@@ -435,7 +436,7 @@ This model *outputs*:
- if `lm_labels` is not `None`:
Outputs the language modeling loss.
- else:
Outputs `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_tokens_embeddings] (or more generally [d_1, ..., d_n, total_tokens_embeddings] where d_1 ... d_n are the dimensions of input_ids)
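A similar sketch for the language modeling head (same assumptions as above):

```python
import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTLMHeadModel

model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()

input_ids = torch.LongTensor([[250, 971, 34, 7]])   # placeholder BPE indices

with torch.no_grad():
    lm_logits = model(input_ids)                     # [1, 4, total_tokens_embeddings]

# When lm_labels is provided, the language modeling loss is returned instead of the logits.
lm_labels = input_ids.clone()                        # placeholder labels
loss = model(input_ids, lm_labels=lm_labels)
```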
#### 11. `OpenAIGPTDoubleHeadsModel`
@@ -452,7 +453,7 @@ This model *outputs*:
- if `lm_labels` and `multiple_choice_labels` are not `None`:
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
- else Outputs a tuple with:
- `lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_tokens_embeddings]
- `multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
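A shape-level sketch for the double heads model. The second argument (`mc_token_mask`) and the input shapes follow the docstring in [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) and should be treated as assumptions, not a definitive API:

```python
import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTDoubleHeadsModel

model = OpenAIGPTDoubleHeadsModel.from_pretrained('openai-gpt', num_special_tokens=2)
model.eval()

batch_size, num_choices, seq_len = 1, 2, 5
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, num_choices, seq_len))
# Mark the token whose hidden state feeds the multiple choice head (here: the last token of each choice).
mc_token_mask = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)
mc_token_mask[:, :, -1] = 1

with torch.no_grad():
    lm_logits, multiple_choice_logits = model(input_ids, mc_token_mask)
# lm_logits:              [batch_size, num_choices, seq_len, total_tokens_embeddings]
# multiple_choice_logits: [batch_size, num_choices]
```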
pytorch_pretrained_bert/modeling_openai.py
@@ -185,8 +185,8 @@ class OpenAIGPTConfig(object):
)
@property
def total_tokens_embeddings(self):
return self.vocab_size + self.n_special
@classmethod
def from_dict(cls, json_object):
@@ -533,45 +533,44 @@ class OpenAIGPTPreTrainedModel(nn.Module):
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
)
# Add additional embeddings for special tokens if needed
# This step also makes sure we are still sharing the output and input embeddings after loading weights
model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
return model
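A usage sketch of this loading path (the `'openai-gpt'` shortcut name and the token count of 2 are illustrative):

```python
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTModel

# Load the pre-trained weights and reserve room for 2 (untrained) special tokens in one call.
model = OpenAIGPTModel.from_pretrained('openai-gpt', num_special_tokens=2)
assert model.config.n_special == 2
assert model.tokens_embed.num_embeddings == model.config.total_tokens_embeddings
```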
class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model ("Improving Language Understanding by Generative Pre-Training").
OpenAI GPT uses a single embedding matrix to store the word and special token embeddings.
Special token embeddings are embeddings for additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
The embeddings are ordered as follows in the token embeddings matrix:
[0,                                            ----------------------
 ...                                            -> word embeddings
 config.vocab_size - 1,                         ______________________
 config.vocab_size,
 ...                                            -> special embeddings
 config.vocab_size + config.n_special - 1]      ______________________
where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
total_tokens_embeddings = config.vocab_size + config.n_special
You should use the associated indices to index the embeddings.
Params:
config: a OpenAIGPTConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1[).
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
Outputs:
`hidden_states`: the encoded-hidden-states at the top of the model
@@ -603,12 +602,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# nn.init.normal_(self.embed.weight, std=0.02)
def set_num_special_tokens(self, num_special_tokens):
" Update input embeddings with new embedding matrice "
" Update input embeddings with new embedding matrice if needed "
if self.config.n_special == num_special_tokens:
return
# Update config
self.config.n_special = num_special_tokens
# Build new embeddings and initialize
old_embed = self.tokens_embed
self.tokens_embed = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
# Initialize all new embeddings (in particular the special tokens)
self.init_weights(self.tokens_embed)
# Copy word and positional embeddings from the previous weights
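A quick sketch of how this resizing step is typically called (assumes the config constructor's default hyper-parameters; the token count of 3 is arbitrary):

```python
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTModel

config = OpenAIGPTConfig()           # default hyper-parameters
model = OpenAIGPTModel(config)
model.set_num_special_tokens(3)      # e.g. three task-specific delimiter tokens

# Existing rows are copied over; only the 3 new rows start from a fresh initialization.
assert model.config.n_special == 3
assert model.tokens_embed.num_embeddings == config.vocab_size + 3
```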
@@ -646,39 +647,36 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model with a Language Modeling head ("Improving Language Understanding by Generative Pre-Training").
There are two main implementation differences between BERT and the OpenAI GPT:
- the use of an LM loss in OpenAI GPT which means the Transformer is trained to predict the NEXT token for each input token
vs. predict the SAME token for BERT (i.e. you need to shift your labels to the right)
- the use, in OpenAI GPT, of a single embedding matrix to store the word and special token embeddings.
Special token embeddings are embeddings for additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
The embeddings are ordered as follows in the token embeddings matrix:
[0,                                            ----------------------
 ...                                            -> word embeddings
 config.vocab_size - 1,                         ______________________
 config.vocab_size,
 ...                                            -> special embeddings
 config.vocab_size + config.n_special - 1]      ______________________
where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
total_tokens_embeddings = config.vocab_size + config.n_special
You should use the associated indices to index the embeddings.
Params:
config: a OpenAIGPTConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1[).
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
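A small sketch of building `lm_labels` with the -1 ignore index (the padding id 0 is just a placeholder):

```python
import torch

input_ids = torch.LongTensor([[250, 971, 34, 7, 0, 0]])   # trailing 0s stand in for padding
lm_labels = input_ids.clone()
lm_labels[input_ids == 0] = -1                             # masked positions contribute nothing to the loss
```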
@@ -687,8 +685,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
if `lm_labels` is not `None`:
Outputs the language modeling loss.
else:
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, sequence_length, total_tokens_embeddings]
(or more generally [d_1, ..., d_n, total_tokens_embeddings] where d_1 ... d_n are the dimensions of input_ids)
Example usage:
```python
@@ -726,45 +724,39 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
"""OpenAI GPT model with a Language Modeling and a Multiple Choice heads ("Improving Language Understanding by Generative Pre-Training").
There are two main implementation differences between BERT and the OpenAI GPT:
- the use of an LM loss in OpenAI GPT which means the Transformer is trained to predict the NEXT token for each input token
vs. predict the SAME token for BERT (i.e. you need to shift your labels to the right)
- the use, in OpenAI GPT, of a single embedding matrix to store the word and special token embeddings.
Special token embeddings are embeddings for additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
The embeddings are ordered as follows in the token embeddings matrix:
[0,                                            ----------------------
 ...                                            -> word embeddings
 config.vocab_size - 1,                         ______________________
 config.vocab_size,
 ...                                            -> special embeddings
 config.vocab_size + config.n_special - 1]      ______________________
where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
total_tokens_embeddings = config.vocab_size + config.n_special
You should use the associated indices to index the embeddings.
Params:
config: a OpenAIGPTConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
where d_1 ... d_n are arbitrary dimensions) with the word BPE token indices selected in the range [0, total_tokens_embeddings[
`mc_token_mask`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with a value of 1 where the last hidden state is (usually the [CLS] token) and 0 otherwise.
`position_ids`: an optional torch.LongTensor with the same shape as input_ids
with the position indices (selected in the range [0, config.n_positions - 1[).
`token_type_ids`: an optional torch.LongTensor with the same shape as input_ids
You can use it to add a third type of embedding to each input token in the sequence
(the previous two being the word and position embeddings).
The input, position and token_type embeddings are summed inside the Transformer before the first
self-attention block.
`lm_labels`: optional language modeling labels: torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with indices selected in [-1, 0, ..., total_tokens_embeddings]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., total_tokens_embeddings]
`multiple_choice_labels`: optional multiple choice labels: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
@@ -772,7 +764,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
if `lm_labels` and `multiple_choice_labels` are not `None`:
Outputs a tuple of losses with the language modeling loss and the multiple choice loss.
else: a tuple with
`lm_logits`: the language modeling logits as a torch.FloatTensor of size [batch_size, num_choices, sequence_length, total_tokens_embeddings]
`multiple_choice_logits`: the multiple choice logits as a torch.FloatTensor of size [batch_size, num_choices]
Example usage: