"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "afe6f744c0d1c763dd544f58d7b4286c8e81b049"
Commit 870d7163 authored by thomwolf

fixing target size in crossentropy losses

parent 982339d8
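
The change enforces the shape contract of PyTorch's CrossEntropyLoss: it expects logits of shape (N, C) and integer targets of shape (N,). BERT's masked-LM scores come out as (batch, seq_len, vocab_size) with labels of shape (batch, seq_len), so both tensors must be flattened before the loss is computed; otherwise the loss would mis-read seq_len as the class dimension. A minimal sketch of the pattern, with illustrative sizes that are assumptions rather than values from the repository:

    import torch
    from torch.nn import CrossEntropyLoss

    batch, seq_len, vocab_size = 2, 8, 30522  # illustrative sizes (assumed)
    prediction_scores = torch.randn(batch, seq_len, vocab_size)
    # -1 marks positions that were not masked and must not contribute to the loss
    masked_lm_labels = torch.full((batch, seq_len), -1, dtype=torch.long)
    masked_lm_labels[:, 3] = 42  # pretend one masked token per sequence

    loss_fct = CrossEntropyLoss(ignore_index=-1)
    # Flatten to (N, C) logits and (N,) targets, as the fixed lines do
    masked_lm_loss = loss_fct(
        prediction_scores.view(-1, vocab_size),
        masked_lm_labels.view(-1),
    )

The next-sentence loss follows the same contract with only two classes, hence the view(-1, 2) on seq_relationship_score in the hunks below.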
@@ -678,8 +678,8 @@ class BertForPreTraining(PreTrainedBertModel):
         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
-            next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
+            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
             return total_loss
         else:
@@ -741,7 +741,7 @@ class BertForMaskedLM(PreTrainedBertModel):
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             return masked_lm_loss
         else:
             return prediction_scores
@@ -803,7 +803,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
+            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             return next_sentence_loss
         else:
             return seq_relationship_score
@@ -856,6 +856,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
     """
     def __init__(self, config, num_labels=2):
         super(BertForSequenceClassification, self).__init__(config)
+        self.num_labels = num_labels
         self.bert = BertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
@@ -868,7 +869,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits, labels)
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
             return loss, logits
         else:
             return logits
...
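
The sequence-classification hunks also store num_labels on the module in __init__, because forward() now needs it to reshape the logits. A minimal self-contained sketch of the same pattern; TinyClassifier is a hypothetical stand-in, not the repository's class:

    import torch
    import torch.nn as nn
    from torch.nn import CrossEntropyLoss

    class TinyClassifier(nn.Module):
        def __init__(self, hidden_size=768, num_labels=2):
            super(TinyClassifier, self).__init__()
            self.num_labels = num_labels  # kept on self, as in the fix
            self.classifier = nn.Linear(hidden_size, num_labels)

        def forward(self, pooled_output, labels=None):
            logits = self.classifier(pooled_output)
            if labels is not None:
                loss_fct = CrossEntropyLoss()
                # (batch, num_labels) logits against (batch,) integer targets
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                return loss, logits
            return logits

    pooled = torch.randn(4, 768)   # stand-in for BERT's pooled [CLS] output
    labels = torch.tensor([0, 1, 1, 0])
    loss, logits = TinyClassifier()(pooled, labels)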