Commit a45a9cc0 authored by thomwolf
Browse files

update tests

parent b12616fd
...@@ -88,13 +88,13 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -88,13 +88,13 @@ class OpenAIGPTModelTest(unittest.TestCase):
total_voc = self.n_ctx + self.n_special + self.vocab_size total_voc = self.n_ctx + self.n_special + self.vocab_size
token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc) token_type_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
multiple_choice_labels = None mc_labels = None
lm_labels = None lm_labels = None
multiple_choice_token_mask = None mc_token_mask = None
if self.use_labels: if self.use_labels:
multiple_choice_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels) lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
multiple_choice_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float() mc_token_mask = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], 2).float()
config = OpenAIGPTConfig( config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
...@@ -110,10 +110,10 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -110,10 +110,10 @@ class OpenAIGPTModelTest(unittest.TestCase):
initializer_range=self.initializer_range) initializer_range=self.initializer_range)
return (config, input_ids, token_type_ids, position_ids, return (config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, multiple_choice_token_mask) mc_labels, lm_labels, mc_token_mask)
def create_openai_model(self, config, input_ids, token_type_ids, position_ids, def create_openai_model(self, config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, multiple_choice_token_mask): mc_labels, lm_labels, mc_token_mask):
model = OpenAIGPTModel(config) model = OpenAIGPTModel(config)
hidden_states = model(input_ids, position_ids, token_type_ids) hidden_states = model(input_ids, position_ids, token_type_ids)
outputs = { outputs = {
...@@ -128,7 +128,7 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -128,7 +128,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids, def create_openai_lm_head(self, config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, multiple_choice_token_mask): mc_labels, lm_labels, mc_token_mask):
model = OpenAIGPTLMHeadModel(config) model = OpenAIGPTLMHeadModel(config)
loss = model(input_ids, position_ids, token_type_ids, lm_labels) loss = model(input_ids, position_ids, token_type_ids, lm_labels)
lm_logits = model(input_ids, position_ids, token_type_ids) lm_logits = model(input_ids, position_ids, token_type_ids)
...@@ -150,15 +150,16 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -150,15 +150,16 @@ class OpenAIGPTModelTest(unittest.TestCase):
[]) [])
def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids, def create_openai_double_heads(self, config, input_ids, token_type_ids, position_ids,
multiple_choice_labels, lm_labels, multiple_choice_token_mask): mc_labels, lm_labels, mc_token_mask):
model = OpenAIGPTDoubleHeadsModel(config) model = OpenAIGPTDoubleHeadsModel(config)
loss = model(input_ids, multiple_choice_token_mask, position_ids, loss = model(input_ids, mc_token_mask,
token_type_ids, lm_labels, multiple_choice_labels) lm_labels=lm_labels, mc_labels=mc_labels,
lm_logits, multiple_choice_logits = model(input_ids, multiple_choice_token_mask, position_ids, token_type_ids) token_type_ids=token_type_ids, position_ids=position_ids)
lm_logits, mc_logits = model(input_ids, mc_token_mask, position_ids=position_ids, token_type_ids=token_type_ids)
outputs = { outputs = {
"loss": loss, "loss": loss,
"lm_logits": lm_logits, "lm_logits": lm_logits,
"multiple_choice_logits": multiple_choice_logits, "mc_logits": mc_logits,
} }
return outputs return outputs
...@@ -168,7 +169,7 @@ class OpenAIGPTModelTest(unittest.TestCase): ...@@ -168,7 +169,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
list(result["lm_logits"].size()), list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc]) [self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["multiple_choice_logits"].size()), list(result["mc_logits"].size()),
[self.batch_size, self.n_choices]) [self.batch_size, self.n_choices])
def check_openai_double_heads_loss_output(self, result): def check_openai_double_heads_loss_output(self, result):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment