Commit 5f96ebc0 authored by Lysandre's avatar Lysandre Committed by Lysandre Debut
Browse files

Style

parent 950c6a4f
...@@ -52,35 +52,35 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -52,35 +52,35 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
class FlaubertModelTester(object): class FlaubertModelTester(object):
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13, batch_size=13,
seq_length=7, seq_length=7,
is_training=True, is_training=True,
use_input_lengths=True, use_input_lengths=True,
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
gelu_activation=True, gelu_activation=True,
sinusoidal_embeddings=False, sinusoidal_embeddings=False,
causal=False, causal=False,
asm=False, asm=False,
n_langs=2, n_langs=2,
vocab_size=99, vocab_size=99,
n_special=0, n_special=0,
hidden_size=32, hidden_size=32,
num_hidden_layers=5, num_hidden_layers=5,
num_attention_heads=4, num_attention_heads=4,
hidden_dropout_prob=0.1, hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
max_position_embeddings=512, max_position_embeddings=512,
type_vocab_size=16, type_vocab_size=16,
type_sequence_label_size=2, type_sequence_label_size=2,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4, num_choices=4,
summary_type="last", summary_type="last",
use_proj=True, use_proj=True,
scope=None, scope=None,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -119,7 +119,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -119,7 +119,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
input_lengths = None input_lengths = None
if self.use_input_lengths: if self.use_input_lengths:
input_lengths = ( input_lengths = (
ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2 ids_tensor([self.batch_size], vocab_size=2) + self.seq_length - 2
) # small variation of seq_length ) # small variation of seq_length
token_type_ids = None token_type_ids = None
...@@ -168,15 +168,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -168,15 +168,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["loss"].size()), []) self.parent.assertListEqual(list(result["loss"].size()), [])
def create_and_check_flaubert_model( def create_and_check_flaubert_model(
self, self,
config, config,
input_ids, input_ids,
token_type_ids, token_type_ids,
input_lengths, input_lengths,
sequence_labels, sequence_labels,
token_labels, token_labels,
is_impossible_labels, is_impossible_labels,
input_mask, input_mask,
): ):
model = FlaubertModel(config=config) model = FlaubertModel(config=config)
model.to(torch_device) model.to(torch_device)
...@@ -193,15 +193,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -193,15 +193,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
) )
def create_and_check_flaubert_lm_head( def create_and_check_flaubert_lm_head(
self, self,
config, config,
input_ids, input_ids,
token_type_ids, token_type_ids,
input_lengths, input_lengths,
sequence_labels, sequence_labels,
token_labels, token_labels,
is_impossible_labels, is_impossible_labels,
input_mask, input_mask,
): ):
model = FlaubertWithLMHeadModel(config) model = FlaubertWithLMHeadModel(config)
model.to(torch_device) model.to(torch_device)
...@@ -220,15 +220,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -220,15 +220,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
) )
def create_and_check_flaubert_simple_qa( def create_and_check_flaubert_simple_qa(
self, self,
config, config,
input_ids, input_ids,
token_type_ids, token_type_ids,
input_lengths, input_lengths,
sequence_labels, sequence_labels,
token_labels, token_labels,
is_impossible_labels, is_impossible_labels,
input_mask, input_mask,
): ):
model = FlaubertForQuestionAnsweringSimple(config) model = FlaubertForQuestionAnsweringSimple(config)
model.to(torch_device) model.to(torch_device)
...@@ -249,15 +249,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -249,15 +249,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_flaubert_qa( def create_and_check_flaubert_qa(
self, self,
config, config,
input_ids, input_ids,
token_type_ids, token_type_ids,
input_lengths, input_lengths,
sequence_labels, sequence_labels,
token_labels, token_labels,
is_impossible_labels, is_impossible_labels,
input_mask, input_mask,
): ):
model = FlaubertForQuestionAnswering(config) model = FlaubertForQuestionAnswering(config)
model.to(torch_device) model.to(torch_device)
...@@ -316,15 +316,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -316,15 +316,15 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size]) self.parent.assertListEqual(list(result["cls_logits"].size()), [self.batch_size])
def create_and_check_flaubert_sequence_classif( def create_and_check_flaubert_sequence_classif(
self, self,
config, config,
input_ids, input_ids,
token_type_ids, token_type_ids,
input_lengths, input_lengths,
sequence_labels, sequence_labels,
token_labels, token_labels,
is_impossible_labels, is_impossible_labels,
input_mask, input_mask,
): ):
model = FlaubertForSequenceClassification(config) model = FlaubertForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
......
...@@ -185,7 +185,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -185,7 +185,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_roberta_for_multiple_choice( def create_and_check_roberta_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
): ):
config.num_choices = self.num_choices config.num_choices = self.num_choices
model = RobertaForMultipleChoice(config=config) model = RobertaForMultipleChoice(config=config)
...@@ -208,7 +208,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -208,7 +208,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_roberta_for_question_answering( def create_and_check_roberta_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
): ):
model = RobertaForQuestionAnswering(config=config) model = RobertaForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment