Commit 213981d8 authored by thomwolf

updating bert API

parent 2b56e988
@@ -814,31 +814,28 @@ class BertForMaskedLM(BertPreTrainedModel):
         masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertForMaskedLM, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask,
-                            output_all_encoded_layers=False,
-                            head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, sequence_output, _ = outputs
-        else:
-            sequence_output, _ = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
 
+        outputs = [prediction_scores] + outputs[2:]  # Add hidden states and attention if they are here
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
-            return masked_lm_loss
-        elif self.output_attentions:
-            return all_attentions, prediction_scores
-        return prediction_scores
+            outputs = [masked_lm_loss] + outputs
+
+        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
 
 
 class BertForNextSentencePrediction(BertPreTrainedModel):
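After this change, `BertForMaskedLM.forward` always returns a list rather than a bare tensor. A minimal usage sketch of the new return convention follows; the import path, config values, and dummy tensors are illustrative assumptions written in the style of the class docstrings, not part of the commit.

```python
import torch
# Import path and config values are assumptions; adjust to this commit's package layout.
from pytorch_pretrained_bert import BertConfig, BertForMaskedLM

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMaskedLM(config)
model.eval()

# Already converted into WordPiece token ids (dummy values).
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
masked_lm_labels = torch.LongTensor([[-1, 51, -1], [-1, -1, -1]])  # -1 positions are ignored by the loss

# Without labels, the first element is the prediction scores.
outputs = model(input_ids, token_type_ids, input_mask)
prediction_scores = outputs[0]  # (batch_size, seq_len, vocab_size)

# With labels, the masked LM loss is prepended to the list.
outputs = model(input_ids, token_type_ids, input_mask, masked_lm_labels=masked_lm_labels)
masked_lm_loss, prediction_scores = outputs[0], outputs[1]
```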
@@ -889,31 +886,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertForNextSentencePrediction, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask,
-                            output_all_encoded_layers=False,
-                            head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, _, pooled_output = outputs
-        else:
-            _, pooled_output = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        pooled_output = outputs[1]
         seq_relationship_score = self.cls(pooled_output)
 
+        outputs = [seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            return next_sentence_loss
-        elif self.output_attentions:
-            return all_attentions, seq_relationship_score
-        return seq_relationship_score
+            outputs = [next_sentence_loss] + outputs
+
+        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
 
 
 class BertForSequenceClassification(BertPreTrainedModel):
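The same list convention applies to `BertForNextSentencePrediction`. A short sketch, under the same illustrative import and config assumptions as above:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForNextSentencePrediction  # assumed import path

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForNextSentencePrediction(config)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
next_sentence_label = torch.LongTensor([0, 1])  # 0 = second segment is the real continuation

seq_relationship_score = model(input_ids)[0]  # (batch_size, 2)

# With a label, the NSP loss comes first.
next_sentence_loss, seq_relationship_score = model(input_ids, next_sentence_label=next_sentence_label)[:2]
```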
@@ -966,25 +961,27 @@ class BertForSequenceClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
         super(BertForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, _, pooled_output = outputs
-        else:
-            _, pooled_output = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        pooled_output = outputs[1]
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
 
+        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             if self.num_labels == 1:
                 # We are doing regression
@@ -993,10 +990,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss
-        elif self.output_attentions:
-            return all_attentions, logits
-        return logits
+            outputs = [loss] + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
 
 
 class BertForMultipleChoice(BertPreTrainedModel):
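For `BertForSequenceClassification` the order is (loss), logits, then whatever optional extras `BertModel` appended. A sketch under the additional assumption (not shown in this hunk) that `BertModel` at this commit appends attention outputs after `(sequence_output, pooled_output)` when `output_attentions=True`:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForSequenceClassification  # assumed import path

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForSequenceClassification(config, num_labels=2, output_attentions=True)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
labels = torch.LongTensor([1, 0])

outputs = model(input_ids, labels=labels)
loss, logits = outputs[0], outputs[1]  # logits: (batch_size, num_labels)
extras = outputs[2:]                   # attention outputs from BertModel; exact structure depends on BertModel at this commit
```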
@@ -1048,36 +1044,37 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, num_choices=2, output_attentions=False, output_hidden_states=False):
         super(BertForMultipleChoice, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.num_choices = num_choices
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, _, pooled_output = outputs
-        else:
-            _, pooled_output = outputs
+        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
+        pooled_output = outputs[1]
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)
 
+        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            return loss
-        elif self.output_attentions:
-            return all_attentions, reshaped_logits
-        return reshaped_logits
+            outputs = [loss] + outputs
+
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
 
 
 class BertForTokenClassification(BertPreTrainedModel):
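`BertForMultipleChoice` keeps its (batch_size, num_choices, seq_len) inputs and now returns `[(loss), reshaped_logits, ...]`. A brief sketch, with imports and values that are illustrative assumptions:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForMultipleChoice  # assumed import path

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMultipleChoice(config, num_choices=2)

# Shape (batch_size=1, num_choices=2, seq_len=3); forward flattens the choice dimension before calling BertModel.
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]]])
labels = torch.LongTensor([0])  # index of the correct choice

outputs = model(input_ids, token_type_ids, labels=labels)
loss, reshaped_logits = outputs[0], outputs[1]  # reshaped_logits: (batch_size, num_choices)
```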
@@ -1130,25 +1127,26 @@ class BertForTokenClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
         super(BertForTokenClassification, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, sequence_output, _ = outputs
-        else:
-            sequence_output, _ = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        sequence_output = outputs[0]
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
 
+        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
@@ -1159,10 +1157,9 @@ class BertForTokenClassification(BertPreTrainedModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss
-        elif self.output_attentions:
-            return all_attentions, logits
-        return logits
+            outputs = [loss] + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
 
 
 class BertForQuestionAnswering(BertPreTrainedModel):
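`BertForTokenClassification` follows the same pattern with per-token logits. A sketch under the same illustrative assumptions:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForTokenClassification  # assumed import path

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForTokenClassification(config, num_labels=5)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])  # only masked-in tokens contribute to the loss
labels = torch.LongTensor([[0, 1, 2], [4, 3, 0]])          # one label id per token

outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss, logits = outputs[0], outputs[1]  # logits: (batch_size, seq_len, num_labels)
```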
@@ -1217,28 +1214,26 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertForQuestionAnswering, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
                 end_positions=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask,
-                            output_all_encoded_layers=False,
-                            head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, sequence_output, _ = outputs
-        else:
-            sequence_output, _ = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        sequence_output = outputs[0]
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
 
+        outputs = [start_logits, end_logits] + outputs[2:]
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:
@@ -1254,7 +1249,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return total_loss
-        elif self.output_attentions:
-            return all_attentions, start_logits, end_logits
-        return start_logits, end_logits
+            outputs = [total_loss] + outputs
+
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
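`BertForQuestionAnswering` now returns both span logits (and, when span positions are given, the averaged loss) in one list. A final sketch, with the same caveats about imports and dummy values being assumptions:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForQuestionAnswering  # assumed import path

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForQuestionAnswering(config)

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
start_positions = torch.LongTensor([1, 0])
end_positions = torch.LongTensor([2, 1])

# Inference: the first two elements are the start/end span logits, each (batch_size, seq_len).
start_logits, end_logits = model(input_ids)[:2]

# Training: the averaged span loss is prepended.
outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
total_loss, start_logits, end_logits = outputs[0], outputs[1], outputs[2]
```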