Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0dd796e3
"vscode:/vscode.git/clone" did not exist on "8bcb37bfb80d77e06001f989ad982c9961a69c31"
Commit
0dd796e3
authored
Mar 24, 2019
by
Catalin Voss
Browse files
Also fix loss function issue with the double head models
parent
472857c4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
2 deletions
+8
-2
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+4
-1
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+4
-1
No files found.
pytorch_pretrained_bert/modeling_gpt2.py
View file @
0dd796e3
...
...
@@ -698,8 +698,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
losses
=
[]
if
lm_labels
is
not
None
:
shift_logits
=
lm_logits
[:,
:
-
1
]
shift_labels
=
lm_labels
[:,
1
:]
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
losses
.
append
(
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
)))
if
mc_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
losses
.
append
(
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
)))
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
0dd796e3
...
...
@@ -811,8 +811,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
losses
=
[]
if
lm_labels
is
not
None
:
shift_logits
=
lm_logits
[:,
:
-
1
]
shift_labels
=
lm_labels
[:,
1
:]
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
losses
.
append
(
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
)))
if
mc_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
losses
.
append
(
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
)))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment