Commit c30139a0 authored by thomwolf

add special tokens to gpt-2

parent b832d5bb
@@ -107,6 +107,7 @@ class GPT2Config(object):
     def __init__(
         self,
         vocab_size_or_config_json_file=50257,
+        n_special=0,
         n_positions=1024,
         n_ctx=1024,
         n_embd=768,
@@ -119,6 +120,7 @@ class GPT2Config(object):
         Args:
             vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `GPT2Model` or a configuration json file.
+            n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLS]', ...)
             n_positions: Number of positional embeddings.
             n_ctx: Size of the causal mask (usually same as n_positions).
             n_embd: Dimensionality of the embeddings and hidden states.
@@ -137,6 +139,7 @@ class GPT2Config(object):
                 self.__dict__[key] = value
         elif isinstance(vocab_size_or_config_json_file, int):
             self.vocab_size = vocab_size_or_config_json_file
+            self.n_special = n_special
             self.n_ctx = n_ctx
             self.n_positions = n_positions
             self.n_embd = n_embd
@@ -150,6 +153,10 @@ class GPT2Config(object):
                 "or the path to a pretrained model config file (str)"
             )

+    @property
+    def total_tokens_embeddings(self):
+        return self.vocab_size + self.n_special
+
     @classmethod
     def from_dict(cls, json_object):
         """Constructs a `GPT2Config` from a Python dictionary of parameters."""
@@ -290,11 +297,12 @@ class GPT2LMHead(nn.Module):
     def __init__(self, model_embeddings_weights, config):
         super(GPT2LMHead, self).__init__()
         self.n_embd = config.n_embd
+        embed_shape = model_embeddings_weights.shape
+        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights):
         embed_shape = model_embeddings_weights.shape
-        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
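
Building the decoder `nn.Linear` once in `__init__` means that re-tying after the embedding matrix has been replaced is just a reassignment of the weight parameter. A standalone sketch of that tying pattern (the sizes are illustrative):

    import torch.nn as nn

    wte = nn.Embedding(50259, 768)               # token embedding matrix (vocab + special)
    decoder = nn.Linear(768, 50259, bias=False)  # output projection, built once
    decoder.weight = wte.weight                  # tied: both modules share one Parameter

    assert decoder.weight is wte.weight          # re-tying is just this reassignment
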
@@ -345,7 +353,7 @@ class GPT2PreTrainedModel(nn.Module):
             )
         self.config = config

-    def set_tied(self):
+    def set_num_special_tokens(self, num_special_tokens):
         pass

     def init_weights(self, module):
@@ -475,14 +483,32 @@ class GPT2PreTrainedModel(nn.Module):
                 "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
             )

-        # Make sure we are still sharing the output and input embeddings after loading weights
-        model.set_tied()
+        # Add additional embeddings for special tokens if needed
+        # This step also makes sure we are still sharing the output and input embeddings after loading weights
+        model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
         return model


 class GPT2Model(GPT2PreTrainedModel):
     """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").

+    GPT-2 uses a single embedding matrix to store the word and special embeddings.
+    Special token embeddings are additional embeddings that are not pre-trained: [SEP], [CLS]...
+    Special tokens need to be trained during fine-tuning if you use them.
+    The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
+
+    The embeddings are ordered as follows in the token embeddings matrix:
+        [0,                                          ----------------------
+         ...                                         -> word embeddings
+         config.vocab_size - 1,                      ______________________
+         config.vocab_size,
+         ...                                         -> special embeddings
+         config.vocab_size + config.n_special - 1]   ______________________
+
+    where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
+        total_tokens_embeddings = config.vocab_size + config.n_special
+    You should use the associated indices to index the embeddings.
+
     Params:
         config: a GPT2Config class instance with the configuration to build a new model
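
To make the layout described in the docstring concrete, a small sketch of how indices for hypothetical special tokens end up being assigned right after the pre-trained vocabulary:

    # Illustrative only: the special tokens occupy the rows right after the word embeddings.
    vocab_size = 50257
    special_tokens = ['[SEP]', '[CLS]']

    special_token_ids = {tok: vocab_size + i for i, tok in enumerate(special_tokens)}
    # {'[SEP]': 50257, '[CLS]': 50258}

    total_tokens_embeddings = vocab_size + len(special_tokens)  # 50259 rows in the matrix
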
@@ -529,6 +555,20 @@ class GPT2Model(GPT2PreTrainedModel):

         self.apply(self.init_weights)

+    def set_num_special_tokens(self, num_special_tokens):
+        " Update input embeddings with new embedding matrix if needed "
+        if self.config.n_special == num_special_tokens:
+            return
+        # Update config
+        self.config.n_special = num_special_tokens
+        # Build new embeddings and initialize all new embeddings (in particular the special tokens)
+        old_embed = self.wte
+        self.wte = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
+        self.wte.to(old_embed.weight.device)
+        self.init_weights(self.wte)
+        # Copy word embeddings from the previous weights
+        self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
+
     def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
         if past is None:
             past_length = 0
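
A hedged sketch of what `set_num_special_tokens` does to a `GPT2Model` instance (tiny config so it runs quickly; the import path and the config values are assumptions):

    import torch
    from pytorch_pretrained_bert import GPT2Config, GPT2Model  # assumed import path

    config = GPT2Config(vocab_size_or_config_json_file=99, n_positions=33, n_ctx=33,
                        n_embd=32, n_layer=2, n_head=4)
    model = GPT2Model(config)
    old_word_rows = model.wte.weight.data[:config.vocab_size].clone()

    model.set_num_special_tokens(3)  # grow the matrix by 3 freshly initialized rows

    assert model.wte.weight.size(0) == config.vocab_size + 3
    # The pre-existing word rows are copied over unchanged; only the new rows are re-initialized.
    assert torch.equal(model.wte.weight.data[:config.vocab_size], old_word_rows)
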
@@ -610,9 +650,11 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
         self.apply(self.init_weights)

-    def set_tied(self):
-        """ Make sure we are sharing the embeddings
+    def set_num_special_tokens(self, num_special_tokens):
+        """ Update input and output embeddings with new embedding matrix
+            Make sure we are sharing the embeddings
         """
+        self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

     def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
@@ -687,9 +729,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         self.multiple_choice_head = GPT2MultipleChoiceHead(config)
         self.apply(self.init_weights)

-    def set_tied(self):
-        """ Make sure we are sharing the embeddings
+    def set_num_special_tokens(self, num_special_tokens):
+        """ Update input and output embeddings with new embedding matrix
+            Make sure we are sharing the embeddings
         """
+        self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

     def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
......
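
Putting the pieces together for the LM-head model: `set_num_special_tokens` resizes the shared embedding matrix and re-ties the output head, so the logits dimension follows `config.total_tokens_embeddings`. A hedged end-to-end sketch; the import path, config values, and the `(lm_logits, presents)` return convention are assumptions based on the rest of the package:

    import torch
    from pytorch_pretrained_bert import GPT2Config, GPT2LMHeadModel  # assumed import path

    config = GPT2Config(vocab_size_or_config_json_file=99, n_positions=33, n_ctx=33,
                        n_embd=32, n_layer=2, n_head=4)
    model = GPT2LMHeadModel(config)

    model.set_num_special_tokens(2)  # grow the embeddings and re-tie the LM head

    input_ids = torch.randint(0, config.total_tokens_embeddings, (1, 5))
    lm_logits, presents = model(input_ids)  # no lm_labels -> logits are returned
    assert lm_logits.size(-1) == config.total_tokens_embeddings  # 99 + 2 = 101
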
@@ -344,11 +344,12 @@ class OpenAIGPTLMHead(nn.Module):
     def __init__(self, model_embeddings_weights, config):
         super(OpenAIGPTLMHead, self).__init__()
         self.n_embd = config.n_embd
+        embed_shape = model_embeddings_weights.shape
+        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights):
         embed_shape = model_embeddings_weights.shape
-        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
@@ -592,8 +593,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     def __init__(self, config):
         super(OpenAIGPTModel, self).__init__(config)
-        num_tokens = config.vocab_size + config.n_special
-        self.tokens_embed = nn.Embedding(num_tokens, config.n_embd)
+        self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
         self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
         block = Block(config.n_ctx, config, scale=True)
......
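
The OpenAI GPT model now sizes its token embedding the same way, via the config's `total_tokens_embeddings`. A hedged sketch (the constructor arguments are assumptions mirroring the GPT-2 config):

    from pytorch_pretrained_bert import OpenAIGPTConfig, OpenAIGPTModel  # assumed import path

    config = OpenAIGPTConfig(vocab_size_or_config_json_file=99, n_special=2,
                             n_positions=33, n_ctx=33, n_embd=32, n_layer=2, n_head=4)
    model = OpenAIGPTModel(config)
    assert model.tokens_embed.num_embeddings == config.total_tokens_embeddings  # 99 + 2
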
@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
                      use_token_type_ids=True,
                      use_labels=True,
                      vocab_size=99,
+                     n_special=1,
                      n_positions=33,
                      n_embd=32,
                      n_layer=5,
@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
             self.use_token_type_ids = use_token_type_ids
             self.use_labels = use_labels
             self.vocab_size = vocab_size
+            self.n_special = n_special
             self.n_positions = n_positions
             self.n_embd = n_embd
             self.n_layer = n_layer
@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
             self.scope = scope

         def prepare_config_and_inputs(self):
-            input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
+            total_num_tokens = self.vocab_size + self.n_special
+            input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
             position_ids = None
             if self.use_position_ids:
@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
             config = GPT2Config(
                 vocab_size_or_config_json_file=self.vocab_size,
+                n_special=self.n_special,
                 n_positions=self.n_positions,
                 n_embd=self.n_embd,
                 n_layer=self.n_layer,
@@ -130,7 +134,7 @@ class GPT2ModelTest(unittest.TestCase):
             return outputs

         def check_gpt2_lm_head_output(self, result):
-            total_voc = self.vocab_size
+            total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.n_choices, self.seq_length, total_voc])
@@ -157,7 +161,7 @@ class GPT2ModelTest(unittest.TestCase):
             return outputs

         def check_gpt2_double_heads_output(self, result):
-            total_voc = self.vocab_size
+            total_voc = self.n_special + self.vocab_size
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.n_choices, self.seq_length, total_voc])
......