Commit d86d57fa (unverified)
chenpangpang/transformers
Authored Nov 18, 2020 by Stas Bekman; committed via GitHub on Nov 18, 2020
Parent: bf3611b2

[s2s] distillation apex breaks return_dict obj (#8631)

* apex breaks return_dict obj
* style
Showing 1 changed file with 10 additions and 8 deletions:

examples/seq2seq/distillation.py (+10, -8)
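Every hunk in this diff applies the same one-line pattern: attribute access on the model outputs (`student_outputs.logits`) becomes key access (`student_outputs["logits"]`). Per the commit title, the reason is that apex breaks the `return_dict` object: when apex AMP recursively casts `forward()` outputs, it appears to rebuild dict subclasses as plain dicts, so the `ModelOutput` attribute view is lost while the keys survive. A minimal sketch of that failure mode, with a simplified stand-in `ModelOutput` and an illustrative `applier` (neither is the real transformers or apex code):

```python
from collections import OrderedDict

class ModelOutput(OrderedDict):
    """Simplified stand-in for transformers' ModelOutput: a dict whose
    entries are also reachable as attributes."""
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

outputs = ModelOutput(logits=[0.1, 0.9])
assert outputs.logits == outputs["logits"]      # both views work here

def applier(value, fn):
    """apex-style recursive cast: dict subclasses come back as plain dicts."""
    if isinstance(value, dict):
        return {k: applier(v, fn) for k, v in value.items()}
    return fn(value)

cast = applier(outputs, lambda x: x)            # ModelOutput -> plain dict
print(cast["logits"])                           # key access still works
# cast.logits                                   # would raise AttributeError
```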
```diff
@@ -154,7 +154,7 @@ class SummarizationDistiller(SummarizationModule):
             output_attentions=False,
             use_cache=False,
         )
-        lm_logits = student_outputs.logits
+        lm_logits = student_outputs["logits"]

         # Same cross entropy vs. label smoothing logic as finetune.py
         assert lm_logits.shape[-1] == self.model.config.vocab_size
```
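Indexing is the common denominator here: `ModelOutput` subclasses `OrderedDict`, so `student_outputs["logits"]` returns the same tensor as `student_outputs.logits` whenever the wrapper is intact, and it keeps working after apex has downgraded the outputs to a plain dict.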
```diff
@@ -171,7 +171,9 @@ class SummarizationDistiller(SummarizationModule):
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)

-        teacher_enc_outputs = student_outputs.encoder_last_hidden_state  # use this unless self.different_base_models
+        teacher_enc_outputs = student_outputs[
+            "encoder_last_hidden_state"
+        ]  # use this unless self.different_base_models
         hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
         if self.different_encoder:  # compute encoder hidden state loss
             all_teacher_encoder_outputs = self.teacher.get_encoder()(
```
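The three-line split of the subscript above is presumably just the `* style` pass noted in the commit message (automatic line-length formatting after the rewrite), not a functional change.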
```diff
@@ -180,12 +182,12 @@ class SummarizationDistiller(SummarizationModule):
                 output_hidden_states=self.do_calc_hidden_loss,
             )
             if self.different_base_models:
-                teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
+                teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
             elif self.do_calc_hidden_loss:
                 hid_loss_enc = self.calc_hidden_loss(
                     src_mask,
-                    student_outputs.encoder_hidden_states,
-                    all_teacher_encoder_outputs.hidden_states,
+                    student_outputs["encoder_hidden_states"],
+                    all_teacher_encoder_outputs["hidden_states"],
                     self.e_matches,
                     normalize_hidden=self.hparams.normalize_hidden,
                 )
```
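`calc_hidden_loss` itself is outside this diff. For orientation, a typical implementation of such an intermediate-layer distillation loss, masked MSE between student layer outputs and the matched subset of teacher layers, might look like the sketch below (the body is an assumption reconstructed from the call sites above, not code from this commit):

```python
import torch
import torch.nn.functional as F

def calc_hidden_loss_sketch(attention_mask, hidden_states, hidden_states_T,
                            matches, normalize_hidden=False):
    """Hypothetical stand-in for calc_hidden_loss: MSE between each student
    layer and its matched teacher layer, averaged over non-pad positions.
    matches[i] gives the teacher layer paired with student layer i."""
    mask = attention_mask.to(hidden_states[0])              # (B, S), cast to float
    valid_count = mask.sum() * hidden_states[0].size(-1)    # non-pad scalars
    student = torch.stack([hidden_states[i] for i in range(len(matches))])
    teacher = torch.stack([hidden_states_T[j] for j in matches])
    if normalize_hidden:  # put both networks' activations on a common scale
        student = F.layer_norm(student, student.shape[1:])
        teacher = F.layer_norm(teacher, teacher.shape[1:])
    mse = F.mse_loss(teacher, student, reduction="none")    # (L, B, S, H)
    # zero out padded positions before averaging
    return (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
```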
```diff
@@ -199,12 +201,12 @@ class SummarizationDistiller(SummarizationModule):
             use_cache=False,  # since we are not passing labels, never let this default to True
         )
         dec_mask = decoder_input_ids.ne(pad_token_id)
-        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
+        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
         if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
             hid_loss_dec = self.calc_hidden_loss(
                 dec_mask,
-                student_outputs.decoder_hidden_states,
-                teacher_outputs.decoder_hidden_states,
+                student_outputs["decoder_hidden_states"],
+                teacher_outputs["decoder_hidden_states"],
                 self.d_matches,
                 normalize_hidden=self.hparams.normalize_hidden,
             )
```
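Similarly, `calc_ce_loss` is only visible here as a call site. A common form for this distillation term, KL divergence between temperature-softened student and teacher distributions over the non-pad decoder positions, is sketched below (the function body and the `temperature` parameter are assumptions, not part of this commit):

```python
import torch
import torch.nn.functional as F

def calc_ce_loss_sketch(mask, s_logits, t_logits, temperature=2.0):
    """Hypothetical stand-in for calc_ce_loss: KL(student || teacher) on
    softened distributions, restricted to positions where mask is True."""
    sel = mask.unsqueeze(-1).expand_as(s_logits)    # (B, S, V) bool mask
    vocab = s_logits.size(-1)
    s = torch.masked_select(s_logits, sel).view(-1, vocab)
    t = torch.masked_select(t_logits, sel).view(-1, vocab)
    # scale by T^2 to keep gradient magnitudes comparable across temperatures
    return (
        F.kl_div(
            F.log_softmax(s / temperature, dim=-1),
            F.softmax(t / temperature, dim=-1),
            reduction="batchmean",
        )
        * temperature ** 2
    )
```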