Commit dc5df92f authored by thomwolf
Browse files

added LM head for OpenAI

parent 3cf12b23
...@@ -5,7 +5,8 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining, ...@@ -5,7 +5,8 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction, BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice, BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering) BertForTokenClassification, BertForQuestionAnswering)
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from .optimization import BertAdam from .optimization import BertAdam
from .optimization_openai import OpenAIAdam from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
...@@ -267,11 +267,11 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): ...@@ -267,11 +267,11 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
nn.init.normal_(self.linear.weight, std = 0.02) nn.init.normal_(self.linear.weight, std = 0.02)
nn.init.normal_(self.linear.bias, 0) nn.init.normal_(self.linear.bias, 0)
def forward(self, hidden_states, classification_token_mask): def forward(self, hidden_states, multiple_choice_token_mask):
# Classification logits # Classification logits
# hidden_states = hidden_states.view(-1, self.n_embd) # hidden_states = hidden_states.view(-1, self.n_embd)
# classification_token_mask = classification_token_mask.view(-1, 1).expand_as(hidden_states) # multiple_choice_token_mask = multiple_choice_token_mask.view(-1, 1).expand_as(hidden_states)
multiple_choice_h = hidden_states * classification_token_mask.unsqueeze(-1) multiple_choice_h = hidden_states * multiple_choice_token_mask.unsqueeze(-1)
multiple_choice_h = multiple_choice_h.sum(dim=-2) multiple_choice_h = multiple_choice_h.sum(dim=-2)
# flat = x[..., 0].contiguous().view(-1) # flat = x[..., 0].contiguous().view(-1)
# multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :] # multiple_choice_h = multiple_choice_h[flat == self.multiple_choice_token, :]
...@@ -496,8 +496,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -496,8 +496,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
if lm_labels is not None: if lm_labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(lm_logits, lm_labels) loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
return loss return loss
return lm_logits return lm_logits
...@@ -515,15 +515,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -515,15 +515,14 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens) self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.embed.weight) self.lm_head.set_embeddings_weights(self.transformer.embed.weight)
def forward(self, input_ids, classification_token_mask, position_ids=None, token_type_ids=None, def forward(self, input_ids, multiple_choice_token_mask, position_ids=None, token_type_ids=None,
lm_labels=None, multiple_choice_labels=None): lm_labels=None, multiple_choice_labels=None):
""" """ input_ids should be of shape B x C x S
input_ids as to be of shape B x C x S
lm_labels can be masked using the -1 value lm_labels can be masked using the -1 value
""" """
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
multiple_choice_logits = self.multiple_choice_head(hidden_states, classification_token_mask) multiple_choice_logits = self.multiple_choice_head(hidden_states, multiple_choice_token_mask)
losses = [] losses = []
if lm_labels is not None: if lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1) loss_fct = CrossEntropyLoss(ignore_index=-1)
......
...@@ -22,7 +22,8 @@ import random ...@@ -22,7 +22,8 @@ import random
import torch import torch
from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel) from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
class OpenAIGPTModelTest(unittest.TestCase): class OpenAIGPTModelTest(unittest.TestCase):
...@@ -89,11 +90,11 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -89,11 +90,11 @@ class OpenAIGPTModelTest(unittest.TestCase):
multiple_choice_labels = None multiple_choice_labels = None
lm_labels = None lm_labels = None
classification_token_mask = None multiple_choice_token_mask = None
if self.use_labels: if self.use_labels:
multiple_choice_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) multiple_choice_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels) lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
classification_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float() multiple_choice_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
config = OpenAIGPTConfig( config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
...@@ -109,10 +110,10 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -109,10 +110,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
initializer_range=self.initializer_range) initializer_range=self.initializer_range)
return (config, input_ids, token_type_ids, position_ids, return (config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, classification_token_mask) multiple_choice_labels, lm_labels, multiple_choice_token_mask)
def create_openai_model(self, config, input_ids, token_type_ids, position_ids, def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, classification_token_mask): multiple_choice_labels, lm_labels, multiple_choice_token_mask):
model = OpenAIGPTModel(config) model = OpenAIGPTModel(config)
hidden_states = model(input_ids, position_ids, token_type_ids) hidden_states = model(input_ids, position_ids, token_type_ids)
outputs = { outputs = {
...@@ -126,12 +127,34 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -126,12 +127,34 @@ class OpenAIGPTModelTest(unittest.TestCase):
[self.batch_size, self.n_choices, self.seq_length, self.n_embd]) [self.batch_size, self.n_choices, self.seq_length, self.n_embd])
def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, multiple_choice_token_mask):
model = OpenAIGPTLMHeadModel(config)
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
lm_logits = model(input_ids, position_ids, token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
}
return outputs
def check_openai_lm_head_output(self, result):
total_voc = self.n_ctx + self.n_special + self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
def check_openai_lm_head_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids, def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, classification_token_mask): multiple_choice_labels, lm_labels, multiple_choice_token_mask):
model = OpenAIGPTDoubleHeadsModel(config) model = OpenAIGPTDoubleHeadsModel(config)
loss = model(input_ids, classification_token_mask, position_ids, loss = model(input_ids, multiple_choice_token_mask, position_ids,
token_type_ids, lm_labels, multiple_choice_labels) token_type_ids, lm_labels, multiple_choice_labels)
lm_logits, multiple_choice_logits = model(input_ids, classification_token_mask, position_ids, token_type_ids) lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask, position_ids, token_type_ids)
outputs = { outputs = {
"loss": loss, "loss": loss,
"lm_logits": lm_logits, "lm_logits": lm_logits,
...@@ -167,6 +190,10 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -167,6 +190,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
output_result = tester.create_openai_model(*config_and_inputs) output_result = tester.create_openai_model(*config_and_inputs)
tester.check_openai_model_output(output_result) tester.check_openai_model_output(output_result)
output_result = tester.create_openai_lm_head(*config_and_inputs)
tester.check_openai_lm_head_output(output_result)
tester.check_openai_lm_head_loss_output(output_result)
output_result = tester.create_openai_double_heads(*config_and_inputs) output_result = tester.create_openai_double_heads(*config_and_inputs)
tester.check_openai_double_heads_output(output_result) tester.check_openai_double_heads_output(output_result)
tester.check_openai_double_heads_loss_output(output_result) tester.check_openai_double_heads_loss_output(output_result)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment