Commit 9bdcba53 authored by thomwolf's avatar thomwolf
Browse files

fix tests

parent f0bf81e1
......@@ -93,7 +93,7 @@ class OpenAIGPTModelTest(unittest.TestCase):
if self.use_labels:
mc_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length).float()
mc_token_ids = OpenAIGPTModelTest.ids_tensor([self.batch_size, self.n_choices], self.seq_length)
config = OpenAIGPTConfig(
vocab_size_or_config_json_file=self.vocab_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment