Commit 213981d8 authored by thomwolf

updating bert API

parent 2b56e988
@@ -814,31 +814,28 @@ class BertForMaskedLM(BertPreTrainedModel):
         masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertForMaskedLM, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask,
-                            output_all_encoded_layers=False,
-                            head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, sequence_output, _ = outputs
-        else:
-            sequence_output, _ = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
 
+        outputs = [prediction_scores] + outputs[2:]  # add hidden states and attention if they are here
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
-            return masked_lm_loss
-        elif self.output_attentions:
-            return all_attentions, prediction_scores
-        return prediction_scores
+            outputs = [masked_lm_loss] + outputs
+        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
 
 class BertForNextSentencePrediction(BertPreTrainedModel):
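With this change `BertForMaskedLM.forward` always returns a single list instead of switching its return type on `masked_lm_labels`/`output_attentions`. A minimal sketch of the new calling convention, reusing the toy tensors from the docstring example above (the `pytorch_pretrained_bert` import path and the `bert-base-uncased` checkpoint name are assumptions, not part of this commit):

```python
import torch
from pytorch_pretrained_bert import BertForMaskedLM  # assumed import path for this branch

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# inference: element 0 is always prediction_scores
outputs = model(input_ids, token_type_ids, input_mask)
prediction_scores = outputs[0]

# training: passing masked_lm_labels prepends the loss, so the scores move to element 1
masked_lm_labels = torch.LongTensor([[-1, -1, 99], [-1, 5, -1]])
outputs = model(input_ids, token_type_ids, input_mask, masked_lm_labels=masked_lm_labels)
masked_lm_loss, prediction_scores = outputs[0], outputs[1]
```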
@@ -889,31 +886,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertForNextSentencePrediction, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask,
-                            output_all_encoded_layers=False,
-                            head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, _, pooled_output = outputs
-        else:
-            _, pooled_output = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        pooled_output = outputs[1]
         seq_relationship_score = self.cls(pooled_output)
 
+        outputs = [seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            return next_sentence_loss
-        elif self.output_attentions:
-            return all_attentions, seq_relationship_score
-        return seq_relationship_score
+            outputs = [next_sentence_loss] + outputs
+        return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
 
 class BertForSequenceClassification(BertPreTrainedModel):
@@ -966,25 +961,27 @@ class BertForSequenceClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
         super(BertForSequenceClassification, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, _, pooled_output = outputs
-        else:
-            _, pooled_output = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        pooled_output = outputs[1]
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
 
+        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             if self.num_labels == 1:
                 # We are doing regression
@@ -993,10 +990,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss
-        elif self.output_attentions:
-            return all_attentions, logits
-        return logits
+            outputs = [loss] + outputs
+        return outputs  # (loss), logits, (hidden_states), (attentions)
 
 class BertForMultipleChoice(BertPreTrainedModel):
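For `BertForSequenceClassification` the convention is the same: when `labels` are passed, the loss is element 0 and the logits element 1; without labels the logits come first. A hedged sketch (import path, checkpoint name, and label values are illustrative assumptions):

```python
import torch
from pytorch_pretrained_bert import BertForSequenceClassification  # assumed import path

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
labels = torch.LongTensor([1, 0])  # one class id per sequence

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

outputs = model(input_ids, token_type_ids, input_mask, labels=labels)
loss, logits = outputs[0], outputs[1]  # logits shape: (batch_size, num_labels)
loss.backward()  # e.g. one training step; with num_labels=1 the head switches to MSE regression
```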
@@ -1048,36 +1044,37 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, num_choices=2, output_attentions=False, output_hidden_states=False):
         super(BertForMultipleChoice, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.num_choices = num_choices
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, 1)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, _, pooled_output = outputs
-        else:
-            _, pooled_output = outputs
+        outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, head_mask=head_mask)
+        pooled_output = outputs[1]
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, self.num_choices)
 
+        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            return loss
-        elif self.output_attentions:
-            return all_attentions, reshaped_logits
-        return reshaped_logits
+            outputs = [loss] + outputs
+        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
 
 class BertForTokenClassification(BertPreTrainedModel):
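For `BertForMultipleChoice` the inputs carry an extra choices dimension and the returned logits are already reshaped to `(batch_size, num_choices)`; the optional loss is again prepended. A sketch under the same assumptions as the earlier examples:

```python
import torch
from pytorch_pretrained_bert import BertForMultipleChoice  # assumed import path

# shape (batch_size=1, num_choices=2, seq_len=3)
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])
input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]]])
labels = torch.LongTensor([0])  # index of the correct choice

model = BertForMultipleChoice.from_pretrained('bert-base-uncased', num_choices=2)

outputs = model(input_ids, token_type_ids, input_mask, labels=labels)
loss, reshaped_logits = outputs[0], outputs[1]  # reshaped_logits: (batch_size, num_choices)
```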
@@ -1130,25 +1127,26 @@ class BertForTokenClassification(BertPreTrainedModel):
         logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
         super(BertForTokenClassification, self).__init__(config)
         self.output_attentions = output_attentions
+        self.output_hidden_states = output_hidden_states
         self.num_labels = num_labels
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, sequence_output, _ = outputs
-        else:
-            sequence_output, _ = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        sequence_output = outputs[0]
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
 
+        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
@@ -1159,10 +1157,9 @@ class BertForTokenClassification(BertPreTrainedModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss
-        elif self.output_attentions:
-            return all_attentions, logits
-        return logits
+            outputs = [loss] + outputs
+        return outputs  # (loss), logits, (hidden_states), (attentions)
 
 class BertForQuestionAnswering(BertPreTrainedModel):
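For `BertForTokenClassification` the labels are per token and, when an `attention_mask` is given, only the active (non-padded) positions contribute to the loss, as the "Only keep active parts of the loss" branch above shows. Sketch under the same assumptions:

```python
import torch
from pytorch_pretrained_bert import BertForTokenClassification  # assumed import path

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])  # the padded position is excluded from the loss
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
labels = torch.LongTensor([[0, 1, 2], [2, 1, 0]])  # one label id per token

model = BertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=3)

outputs = model(input_ids, token_type_ids, input_mask, labels=labels)
loss, logits = outputs[0], outputs[1]  # logits shape: (batch_size, seq_len, num_labels)
```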
@@ -1217,28 +1214,26 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+    def __init__(self, config, output_attentions=False, output_hidden_states=False):
         super(BertForQuestionAnswering, self).__init__(config)
         self.output_attentions = output_attentions
-        self.bert = BertModel(config, output_attentions=output_attentions,
-                              keep_multihead_output=keep_multihead_output)
+        self.output_hidden_states = output_hidden_states
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
         self.apply(self.init_weights)
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
                 end_positions=None, head_mask=None):
-        outputs = self.bert(input_ids, token_type_ids, attention_mask,
-                            output_all_encoded_layers=False,
-                            head_mask=head_mask)
-        if self.output_attentions:
-            all_attentions, sequence_output, _ = outputs
-        else:
-            sequence_output, _ = outputs
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, head_mask=head_mask)
+        sequence_output = outputs[0]
         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
 
+        outputs = [start_logits, end_logits] + outputs[2:]
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split adds a dimension
             if len(start_positions.size()) > 1:
@@ -1254,7 +1249,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return total_loss
-        elif self.output_attentions:
-            return all_attentions, start_logits, end_logits
-        return start_logits, end_logits
+            outputs = [total_loss] + outputs
+        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
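Finally, `BertForQuestionAnswering` now returns `start_logits` and `end_logits` as the first two elements, with the averaged span loss prepended when `start_positions`/`end_positions` are supplied. Sketch under the same assumptions as the earlier examples:

```python
import torch
from pytorch_pretrained_bert import BertForQuestionAnswering  # assumed import path

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

# inference: the first two elements are the span logits
outputs = model(input_ids, token_type_ids, input_mask)
start_logits, end_logits = outputs[0], outputs[1]

# training: supplying the gold span positions prepends (start_loss + end_loss) / 2
start_positions = torch.LongTensor([1, 0])
end_positions = torch.LongTensor([2, 1])
outputs = model(input_ids, token_type_ids, input_mask,
                start_positions=start_positions, end_positions=end_positions)
total_loss = outputs[0]
```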