Commit 8abfee9e authored by Rémi Louf

rename Bert2Bert -> Bert2Rnd

parent 82628b0f
@@ -64,7 +64,7 @@ if is_torch_available():
                                 BertForMaskedLM, BertForNextSentencePrediction,
                                 BertForSequenceClassification, BertForMultipleChoice,
                                 BertForTokenClassification, BertForQuestionAnswering,
-                                load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, Bert2Bert)
+                                load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, Bert2Rnd)
     from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
                                   OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
                                   load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
@@ -1419,7 +1419,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
 @add_start_docstrings("Bert encoder-decoder model for sequence generation.",
                       BERT_START_DOCSTRING,
                       BERT_INPUTS_DOCSTRING)
-class Bert2Bert(BertPreTrainedModel):
+class Bert2Rnd(BertPreTrainedModel):
     r"""
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -1434,7 +1434,8 @@ class Bert2Bert(BertPreTrainedModel):
     Examples::

         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
-        model = Bert2Bert.from_pretrained('bert-base-uncased')
+        model = Bert2Rnd.from_pretrained('bert-base-uncased')
+        # fine-tuning magic happens here
         input = tokenizer.encode("Hello, how are you?")
         outputs = model(input)
         output_text = tokenizer.decode(outputs[0])
@@ -1468,4 +1469,4 @@ class Bert2Bert(BertPreTrainedModel):
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask)
-        return decoder_outputs
+        return decoder_outputs[0]
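A behavioral change rides along with the rename: forward now returns decoder_outputs[0] instead of the whole tuple. BERT-style models in this codebase return tuples whose first element is the sequence of hidden states, so callers now receive that tensor directly. A minimal sketch of the difference; the tuple layout and shapes here are assumptions modeled on the BERT outputs of that era, not the actual Bert2Rnd return value:

import torch

# Stand-in for what the decoder returns: a tuple whose first element
# is the sequence output of shape (batch, seq_len, hidden_size).
# Hypothetical values for illustration only.
sequence_output = torch.randn(2, 7, 768)
pooled_output = torch.randn(2, 768)
decoder_outputs = (sequence_output, pooled_output)

before = decoder_outputs        # old behavior: caller gets the whole tuple
after = decoder_outputs[0]      # new behavior: caller gets the sequence output

assert torch.equal(after, before[0])
print(after.shape)              # torch.Size([2, 7, 768])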
@@ -29,7 +29,7 @@ if is_torch_available():
     from transformers import (BertConfig, BertModel, BertForMaskedLM,
                               BertForNextSentencePrediction, BertForPreTraining,
                               BertForQuestionAnswering, BertForSequenceClassification,
-                              BertForTokenClassification, BertForMultipleChoice, Bert2Bert)
+                              BertForTokenClassification, BertForMultipleChoice, Bert2Rnd)
     from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
 else:
     pytestmark = pytest.mark.skip("Require Torch")
@@ -257,7 +257,7 @@ class BertModelTest(CommonTestCases.CommonModelTester):
     def create_and_check_bert2bert(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         config.num_choices = self.num_choices
-        model = Bert2Bert(config=config)
+        model = Bert2Rnd(config=config)
         model.eval()
         bert2bert_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
         bert2bert_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
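Putting the pieces together, the docstring's usage pattern reads as follows after the rename. Treat this as a sketch of intended usage, not a stable API: Bert2Rnd existed only on this work-in-progress branch, and wrapping the encoded ids in a batched tensor is an addition of mine to make the snippet runnable:

import torch
from transformers import BertTokenizer, Bert2Rnd  # Bert2Rnd: branch-only, experimental

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = Bert2Rnd.from_pretrained('bert-base-uncased')
# fine-tuning magic happens here

input_ids = torch.tensor([tokenizer.encode("Hello, how are you?")])  # shape (1, seq_len)
outputs = model(input_ids)  # after this commit: the decoder's first output tensor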