chenpangpang / transformers · Commits
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "8d99bffbdcb1db3496fa64c92fe6fe4009b524e1"
Commit ea9dbea9, authored May 07, 2019 by thomwolf

update GPT2 loss computation for more flexibility

parent ce863365
Showing 1 changed file with 6 additions and 7 deletions:

pytorch_pretrained_bert/modeling_gpt2.py (+6, -7)
@@ -336,6 +336,7 @@ class GPT2MultipleChoiceHead(nn.Module):
         # (bsz, num_choices, 1, hidden_size)
         multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
         # (bsz, num_choices, hidden_size)
+        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
         multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
         # (bsz, num_choices)
         return multiple_choice_logits
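
For context, a minimal sketch of the gather step in this hunk, with illustrative shapes; the sizes and the explicit index expansion are assumptions for the example, not code from the file (in the actual head, mc_token_ids is expanded internally before the gather):

import torch

# Pick out, per choice, the hidden state at the position of the
# classification token (shapes follow the comments in the hunk above).
bsz, num_choices, seq_len, hidden_size = 2, 4, 10, 8
hidden_states = torch.randn(bsz, num_choices, seq_len, hidden_size)

# mc_token_ids gives the classification-token index for each choice;
# gather needs it expanded to (bsz, num_choices, 1, hidden_size).
mc_token_ids = torch.randint(seq_len, (bsz, num_choices))
index = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, hidden_size)

multiple_choice_h = hidden_states.gather(2, index).squeeze(2)
print(multiple_choice_h.shape)  # torch.Size([2, 4, 8])

The added dropout line transposes before and after the call; if self.dropout is an nn.Dropout2d, as in the analogous OpenAI GPT multiple-choice head in this codebase, the transpose moves hidden_size into the channel slot so whole hidden units are dropped at once, mimicking TensorFlow's noise_shape behavior.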
@@ -665,9 +666,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
-            shift_logits = lm_logits[:, :-1].contiguous()
-            shift_labels = lm_labels[:, 1:].contiguous()
-
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
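
The switch from [:, :-1] to [..., :-1, :] is what buys the flexibility named in the commit message: the old slice assumed the sequence axis was dim 1, so on (bsz, num_choices, seq, vocab) logits it would wrongly cut the num_choices axis, while the ellipsis form shifts the second-to-last axis whatever the leading dimensions are. A minimal runnable sketch, with shape values that are assumptions for the example:

import torch
from torch.nn import CrossEntropyLoss

# Shifted next-token loss as computed after this change; the extra
# num_choices dimension shows why "..." indexing matters.
bsz, num_choices, seq_len, vocab = 2, 3, 5, 11
lm_logits = torch.randn(bsz, num_choices, seq_len, vocab)
lm_labels = torch.randint(vocab, (bsz, num_choices, seq_len))

# Drop the last position of the logits and the first position of the
# labels, regardless of how many leading dimensions there are.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()

# Flatten everything but the vocab dimension; ignore_index=-1 masks
# label positions set to -1 out of the loss.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
print(loss.shape)  # torch.Size([]) - a scalar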
@@ -746,11 +746,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
         if lm_labels is not None:
-            shift_logits = lm_logits[:, :-1].contiguous()
-            shift_labels = lm_labels[:, 1:].contiguous()
-
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
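
For the multiple-choice term in this hunk, mc_logits carries one score per candidate and mc_labels the index of the correct candidate, so a plain CrossEntropyLoss over (bsz, num_choices) applies. A short sketch with assumed sizes; the values and the final sum over the list are illustrative, not from the file:

import torch
from torch.nn import CrossEntropyLoss

# Classification over the candidate answers; sizes are illustrative.
bsz, num_choices = 2, 4
mc_logits = torch.randn(bsz, num_choices)       # one score per choice
mc_labels = torch.randint(num_choices, (bsz,))  # index of the right choice

losses = []
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))

# Callers typically reduce the returned list themselves, e.g.:
total_loss = sum(losses)

The view(-1, mc_logits.size(-1)) call is a no-op on a 2-D tensor; it mirrors the LM branch so both loss terms go through the same flatten-then-CrossEntropyLoss shape convention.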