Commit 33c01368 authored by Rémi Louf's avatar Rémi Louf
Browse files

remove Bert2Rnd test

parent 07520696
...@@ -29,7 +29,7 @@ if is_torch_available(): ...@@ -29,7 +29,7 @@ if is_torch_available():
from transformers import (BertConfig, BertModel, BertForMaskedLM, from transformers import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining, BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification, BertForMultipleChoice, Bert2Rnd) BertForTokenClassification, BertForMultipleChoice)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
else: else:
pytestmark = pytest.mark.skip("Require Torch") pytestmark = pytest.mark.skip("Require Torch")
...@@ -255,17 +255,6 @@ class BertModelTest(CommonTestCases.CommonModelTester): ...@@ -255,17 +255,6 @@ class BertModelTest(CommonTestCases.CommonModelTester):
[self.batch_size, self.num_choices]) [self.batch_size, self.num_choices])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_bert2rnd(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_choices = self.num_choices
model = Bert2Rnd(config=config)
model.eval()
bert2rnd_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
bert2rnd_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
bert2rnd_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
_ = model(bert2rnd_inputs_ids,
attention_mask=bert2rnd_input_mask,
token_type_ids=bert2rnd_token_type_ids)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask, (config, input_ids, token_type_ids, input_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment