Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
94980b52
Unverified
Commit
94980b52
authored
Apr 03, 2019
by
Thomas Wolf
Committed by
GitHub
Apr 03, 2019
Browse files
Merge pull request #404 from CatalinVoss/fix_lm_loss
Fix Language Modeling Loss
parents
9ca25ce8
01520d54
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
4 deletions
+22
-4
pytorch_pretrained_bert/modeling_gpt2.py
pytorch_pretrained_bert/modeling_gpt2.py
+11
-2
pytorch_pretrained_bert/modeling_openai.py
pytorch_pretrained_bert/modeling_openai.py
+11
-2
No files found.
pytorch_pretrained_bert/modeling_gpt2.py
View file @
94980b52
...
@@ -617,8 +617,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -617,8 +617,14 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
hidden_states
,
presents
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
)
hidden_states
,
presents
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
# Shift so that tokens < n predict n
shift_logits
=
lm_logits
[:,
:
-
1
].
contiguous
()
shift_labels
=
lm_labels
[:,
1
:].
contiguous
()
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
))
loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
return
loss
return
loss
return
lm_logits
,
presents
return
lm_logits
,
presents
...
@@ -690,8 +696,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
...
@@ -690,8 +696,11 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
losses
=
[]
losses
=
[]
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
shift_logits
=
lm_logits
[:,
:
-
1
].
contiguous
()
shift_labels
=
lm_labels
[:,
1
:].
contiguous
()
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
losses
.
append
(
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
)))
if
mc_labels
is
not
None
:
if
mc_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
losses
.
append
(
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
)))
...
...
pytorch_pretrained_bert/modeling_openai.py
View file @
94980b52
...
@@ -716,8 +716,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -716,8 +716,14 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
hidden_states
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
# Shift so that tokens < n predict n
shift_logits
=
lm_logits
[:,
:
-
1
].
contiguous
()
shift_labels
=
lm_labels
[:,
1
:].
contiguous
()
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
))
loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
return
loss
return
loss
return
lm_logits
return
lm_logits
...
@@ -803,8 +809,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -803,8 +809,11 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
)
losses
=
[]
losses
=
[]
if
lm_labels
is
not
None
:
if
lm_labels
is
not
None
:
shift_logits
=
lm_logits
[:,
:
-
1
].
contiguous
()
shift_labels
=
lm_labels
[:,
1
:].
contiguous
()
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
losses
.
append
(
loss_fct
(
lm_logits
.
view
(
-
1
,
lm_logits
.
size
(
-
1
)),
lm_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
)))
if
mc_labels
is
not
None
:
if
mc_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
losses
.
append
(
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
)))
losses
.
append
(
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
)))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment