Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
870d7163
Commit
870d7163
authored
Nov 26, 2018
by
thomwolf
Browse files
fixing target size in crossentropy losses
parent
982339d8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
5 deletions
+6
-5
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+6
-5
No files found.
pytorch_pretrained_bert/modeling.py
View file @
870d7163
...
...
@@ -678,8 +678,8 @@ class BertForPreTraining(PreTrainedBertModel):
if
masked_lm_labels
is
not
None
and
next_sentence_label
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
masked_lm_loss
=
loss_fct
(
prediction_scores
,
masked_lm_labels
)
next_sentence_loss
=
loss_fct
(
seq_relationship_score
,
next_sentence_label
)
# Flatten scores to (batch*seq_len, vocab_size) and targets to (batch*seq_len,)
# as CrossEntropyLoss expects input (N, C) and target (N).
# BUG FIX: the original wrote `masked_lm_labels(-1)`, which *calls* the labels
# tensor (TypeError at runtime) instead of reshaping it — it must be
# `masked_lm_labels.view(-1)`, matching `next_sentence_label.view(-1)` below.
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss
=
loss_fct
(
seq_relationship_score
.
view
(
-
1
,
2
)
,
next_sentence_label
.
view
(
-
1
)
)
total_loss
=
masked_lm_loss
+
next_sentence_loss
return
total_loss
else
:
...
...
@@ -741,7 +741,7 @@ class BertForMaskedLM(PreTrainedBertModel):
if
masked_lm_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
masked_lm_loss
=
loss_fct
(
prediction_scores
,
masked_lm_labels
)
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
)
,
masked_lm_labels
.
view
(
-
1
)
)
return
masked_lm_loss
else
:
return
prediction_scores
...
...
@@ -803,7 +803,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
if
next_sentence_label
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
next_sentence_loss
=
loss_fct
(
seq_relationship_score
,
next_sentence_label
)
next_sentence_loss
=
loss_fct
(
seq_relationship_score
.
view
(
-
1
,
2
)
,
next_sentence_label
.
view
(
-
1
)
)
return
next_sentence_loss
else
:
return
seq_relationship_score
...
...
@@ -856,6 +856,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
"""
def __init__(self, config, num_labels=2):
    """Set up a BERT encoder topped by a dropout + linear classification head.

    Args:
        config: model configuration object; provides `hidden_dropout_prob`
            and `hidden_size` for the head (presumably a BertConfig — the
            class definition is outside this view).
        num_labels: number of target classes for the classifier (default 2).
    """
    super(BertForSequenceClassification, self).__init__(config)
    self.num_labels = num_labels
    # Shared BERT encoder producing the pooled sentence representation.
    self.bert = BertModel(config)
    # Regularize the pooled output before it reaches the head.
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    # Project the hidden representation to per-class logits.
    self.classifier = nn.Linear(config.hidden_size, num_labels)
...
...
@@ -868,7 +869,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
,
labels
)
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
)
)
return
loss
,
logits
else
:
return
logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment