Commit 5e1207b8 authored by thomwolf

add attention to all bert models and add test

parent bcc9e93e
@@ -813,15 +813,20 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForPreTraining, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
-        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                                   output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, pooled_output = outputs
+        else:
+            sequence_output, pooled_output = outputs
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

         if masked_lm_labels is not None and next_sentence_label is not None:
@@ -830,8 +835,9 @@ class BertForPreTraining(BertPreTrainedModel):
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
             return total_loss
-        else:
-            return prediction_scores, seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, prediction_scores, seq_relationship_score
+        return prediction_scores, seq_relationship_score
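This same pattern repeats for every head model below: with `output_attentions=True` the per-layer attention probabilities are prepended to the returned tuple, and the old return contract is kept otherwise. A minimal usage sketch after this change (toy config with illustrative sizes, not part of the commit):

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForPreTraining

# Tiny illustrative config; sizes are arbitrary, but hidden_size must be
# divisible by num_attention_heads.
config = BertConfig(vocab_size_or_config_json_file=1000, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=64)

model = BertForPreTraining(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 1000, (3, 7))  # (batch_size, seq_length)
# Without labels and with the flag set, attentions come first in the tuple.
all_attentions, prediction_scores, seq_relationship_score = model(input_ids)
assert len(all_attentions) == config.num_hidden_layers  # one map per layer
```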
 class BertForMaskedLM(BertPreTrainedModel):

@@ -876,23 +882,29 @@ class BertForMaskedLM(BertPreTrainedModel):
     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForMaskedLM, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
-                                       output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         prediction_scores = self.cls(sequence_output)

         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             return masked_lm_loss
-        else:
-            return prediction_scores
+        elif self.output_attentions:
+            return all_attentions, prediction_scores
+        return prediction_scores
 class BertForNextSentencePrediction(BertPreTrainedModel):

@@ -938,23 +950,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForNextSentencePrediction, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
-                                     output_all_encoded_layers=False)
-        seq_relationship_score = self.cls( pooled_output)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
+                            output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
+        seq_relationship_score = self.cls(pooled_output)

         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             return next_sentence_loss
-        else:
-            return seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, seq_relationship_score
+        return seq_relationship_score
 class BertForSequenceClassification(BertPreTrainedModel):

@@ -1002,16 +1020,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels):
+    def __init__(self, config, num_labels, output_attentions=False):
         super(BertForSequenceClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)

@@ -1019,8 +1042,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits
 class BertForMultipleChoice(BertPreTrainedModel):

@@ -1067,10 +1091,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices):
+    def __init__(self, config, num_choices, output_attentions=False):
         super(BertForMultipleChoice, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_choices = num_choices
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_bert_weights)

@@ -1079,7 +1104,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)

@@ -1088,8 +1117,9 @@ class BertForMultipleChoice(BertPreTrainedModel):
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
             return loss
-        else:
-            return reshaped_logits
+        elif self.output_attentions:
+            return all_attentions, reshaped_logits
+        return reshaped_logits
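The reshaping in BertForMultipleChoice is the one non-obvious step: the choice dimension is folded into the batch before the encoder runs, then unfolded so CrossEntropyLoss sees one logit per choice. A shape walk-through with stand-in tensors (illustrative numbers, not the commit's code):

```python
import torch
import torch.nn as nn

batch_size, num_choices, seq_length, hidden_size = 2, 4, 7, 32  # illustrative

input_ids = torch.zeros(batch_size, num_choices, seq_length, dtype=torch.long)
flat_input_ids = input_ids.view(-1, input_ids.size(-1))  # (batch*choices, seq)

# Stand-in for BertModel's pooled output on the flattened batch.
pooled_output = torch.randn(batch_size * num_choices, hidden_size)
classifier = nn.Linear(hidden_size, 1)          # one score per (example, choice)
logits = classifier(pooled_output)              # (batch*choices, 1)
reshaped_logits = logits.view(-1, num_choices)  # (batch, choices)
assert reshaped_logits.shape == (batch_size, num_choices)
```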
 class BertForTokenClassification(BertPreTrainedModel):

@@ -1137,16 +1167,21 @@ class BertForTokenClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels):
+    def __init__(self, config, num_labels, output_attentions=False):
         super(BertForTokenClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)

@@ -1161,8 +1196,9 @@ class BertForTokenClassification(BertPreTrainedModel):
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss
-        else:
-            return logits
+        elif self.output_attentions:
+            return all_attentions, logits
+        return logits
 class BertForQuestionAnswering(BertPreTrainedModel):

@@ -1212,16 +1248,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
     start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForQuestionAnswering, self).__init__(config)
-        self.bert = BertModel(config)
-        # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
-        # self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)

@@ -1243,5 +1282,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
             return total_loss
-        else:
-            return start_logits, end_logits
+        elif self.output_attentions:
+            return all_attentions, start_logits, end_logits
+        return start_logits, end_logits
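The companion changes to BertModel, BertEncoder, and BertSelfAttention that produce `all_attentions` are collapsed in this view. A runnable sketch of the collection mechanism those classes would need, reconstructed as an assumption from the unpacking pattern above (illustrative stand-ins, not the commit's actual code):

```python
import torch
import torch.nn as nn

class DummyLayer(nn.Module):
    """Stand-in for BertLayer: returns (attention_probs, hidden_states)."""
    def __init__(self, hidden_size, num_heads):
        super(DummyLayer, self).__init__()
        self.num_heads = num_heads
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask):
        batch, seq, _ = hidden_states.size()
        # Uniform attention as a placeholder for the real softmaxed scores.
        attention_probs = torch.full((batch, self.num_heads, seq, seq), 1.0 / seq)
        return attention_probs, self.dense(hidden_states)

class AttentionCollectingEncoder(nn.Module):
    """Illustrative stand-in for BertEncoder with output_attentions support."""
    def __init__(self, layers, output_attentions=False):
        super(AttentionCollectingEncoder, self).__init__()
        self.output_attentions = output_attentions
        self.layer = nn.ModuleList(layers)

    def forward(self, hidden_states, attention_mask=None):
        all_attentions = []
        for layer_module in self.layer:
            # Collect one (batch, heads, seq, seq) map per layer.
            attention_probs, hidden_states = layer_module(hidden_states, attention_mask)
            all_attentions.append(attention_probs)
        if self.output_attentions:
            return all_attentions, hidden_states
        return hidden_states

encoder = AttentionCollectingEncoder([DummyLayer(32, 4) for _ in range(2)],
                                     output_attentions=True)
attns, out = encoder(torch.randn(3, 7, 32))
assert len(attns) == 2  # one attention map per layer
```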
@@ -28,7 +28,7 @@ import torch
 from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForNextSentencePrediction, BertForPreTraining,
                                      BertForQuestionAnswering, BertForSequenceClassification,
-                                     BertForTokenClassification)
+                                     BertForTokenClassification, BertForMultipleChoice)
 from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
                      type_sequence_label_size=2,
                      initializer_range=0.02,
                      num_labels=3,
+                     num_choices=4,
                      scope=None):
             self.parent = parent
             self.batch_size = batch_size
@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
             self.type_sequence_label_size = type_sequence_label_size
             self.initializer_range = initializer_range
             self.num_labels = num_labels
+            self.num_choices = num_choices
             self.scope = scope
         def prepare_config_and_inputs(self):
@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
             sequence_labels = None
             token_labels = None
+            choice_labels = None
             if self.use_labels:
                 sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
                 token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+                choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)

             config = BertConfig(
                 vocab_size_or_config_json_file=self.vocab_size,
@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
                 type_vocab_size=self.type_vocab_size,
                 initializer_range=self.initializer_range)

-            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
+            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
         def check_loss_output(self, result):
             self.parent.assertListEqual(
                 list(result["loss"].size()),
                 [])

-        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertModel(config=config)
             model.eval()
             all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
             self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])

-        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForMaskedLM(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
                 list(result["prediction_scores"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])

-        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForNextSentencePrediction(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForPreTraining(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, 2])

-        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForQuestionAnswering(config=config)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length])

-        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.num_labels])

-        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+        def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
             model = BertForTokenClassification(config=config, num_labels=self.num_labels)
             model.eval()
             loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase):
                 [self.batch_size, self.seq_length, self.num_labels])
+        def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
+            model.eval()
+            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            loss = model(multiple_choice_inputs_ids,
+                         multiple_choice_token_type_ids,
+                         multiple_choice_input_mask,
+                         choice_labels)
+            logits = model(multiple_choice_inputs_ids,
+                           multiple_choice_token_type_ids,
+                           multiple_choice_input_mask)
+            outputs = {
+                "loss": loss,
+                "logits": logits,
+            }
+            return outputs
+
+        def check_bert_for_multiple_choice(self, result):
+            self.parent.assertListEqual(
+                list(result["logits"].size()),
+                [self.batch_size, self.num_choices])
+
+        def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+            for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                                BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                                BertForTokenClassification):
+                if model_class in [BertForSequenceClassification,
+                                   BertForTokenClassification]:
+                    model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
+                else:
+                    model = model_class(config=config, output_attentions=True)
+                model.eval()
+                output = model(input_ids, token_type_ids, input_mask)
+                attentions = output[0]
+                self.parent.assertEqual(len(attentions), self.num_hidden_layers)
+                self.parent.assertListEqual(
+                    list(attentions[0].size()),
+                    [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
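The shape assertion above pins down the exposed format: a list of `num_hidden_layers` tensors, each `(batch_size, num_attention_heads, seq_length, seq_length)`, whose rows are softmax distributions. A small consumer sketch with stand-in data (not part of the test):

```python
import torch

# Stand-in for the `attentions` list checked above: 5 layers, batch 13,
# 4 heads, sequence length 7, rows normalized like real attention probs.
attentions = [torch.softmax(torch.randn(13, 4, 7, 7), dim=-1) for _ in range(5)]

stacked = torch.stack(attentions)  # (layers, batch, heads, seq, seq)
# Each row sums to 1, so per-row entropy measures how diffuse a head is.
entropy = -(stacked * stacked.clamp_min(1e-9).log()).sum(-1)
print(entropy.mean())  # scalar summary across layers, heads, and positions
```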
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))

@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase):
         tester.check_bert_for_token_classification_output(output_result)
         tester.check_loss_output(output_result)

+        output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
+        tester.check_bert_for_multiple_choice(output_result)
+        tester.check_loss_output(output_result)
+
+        tester.create_and_check_bert_for_attentions(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""