chenpangpang/transformers, commit d5d2744a (unverified)
Authored Oct 05, 2020 by Sam Shleifer; committed via GitHub on Oct 05, 2020
Support T5 Distillation w/hidden state supervision (#7599)
parent 818c294f

Showing 2 changed files with 37 additions and 30 deletions:
  examples/seq2seq/distillation.py          +37 -27
  examples/seq2seq/test_seq2seq_examples.py  +0  -3
examples/seq2seq/distillation.py

@@ -28,7 +28,7 @@ from lightning_base import generic_train  # noqa
 class BartSummarizationDistiller(SummarizationModule):
     """Supports Bart, Pegasus and other models that inherit from Bart."""

-    loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
+    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]

     def __init__(self, hparams):
         assert Path(hparams.data_dir).exists()
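Dropping "enc_mse_loss" shortens the tuple returned by _step to five entries, one per name in loss_names. A minimal sketch of how such a list is typically zipped with the returned tensors for logging (illustrative only, not the module's actual training_step; values are made up):

import torch

# Pair the loss names from the hunk above with a tuple of loss tensors, the way
# a Lightning-style training step commonly builds its log dict.
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
loss_tensors = (torch.tensor(2.3), torch.tensor(1.1), torch.tensor(0.9), torch.tensor(0.2), torch.tensor(0.1))

logs = {name: round(t.item(), 4) for name, t in zip(loss_names, loss_tensors)}
print(logs)  # {'loss': 2.3, 'ce_loss': 1.1, 'mlm_loss': 0.9, 'hid_loss_enc': 0.2, 'hid_loss_dec': 0.1}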
@@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule):
         if hparams.length_penalty != -1:
             student.config.length_penalty = hparams.length_penalty
         super().__init__(hparams, model=student, config=student.config)
+        model_type = student.config.model_type
         self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]
-        self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
-        self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
+        if model_type == "t5":
+            teacher_encoder_layers = len(teacher.get_encoder().block)
+            teacher_decoder_layers = len(teacher.get_decoder().block)
+        else:
+            teacher_encoder_layers = teacher.config.encoder_layers
+            teacher_decoder_layers = teacher.config.decoder_layers
+        self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
+        self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
         self.teacher = teacher
         freeze_params(self.teacher)
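T5 configs do not carry encoder_layers/decoder_layers attributes, which is why the new branch counts the blocks on the encoder and decoder stacks instead. A minimal sketch of the same check as a standalone helper (loading the checkpoint needs network access; "t5-small" is only an illustrative choice):

from transformers import AutoModelForSeq2SeqLM

def count_teacher_layers(teacher):
    # T5 stacks expose their layers as a ModuleList named `block`;
    # BART-like models record the counts directly on the config.
    if teacher.config.model_type == "t5":
        return len(teacher.get_encoder().block), len(teacher.get_decoder().block)
    return teacher.config.encoder_layers, teacher.config.decoder_layers

teacher = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
print(count_teacher_layers(teacher))  # (6, 6) for t5-small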
@@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule):
            del self.teacher.encoder
         # Intermediate supervision: Decide which layers to supervise
         if hparams.supervise_forward:
-            self.d_matches = get_layers_to_supervise(
-                n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
-            )
-        else:
+            self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
+            self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
+        else:  # student layer should emulate hidden states of the teacher layer it was copied from
+            self.e_matches = self.e_layer_ids
             self.d_matches = self.d_layer_ids
         self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
         self.temperature = 2.0
         self.alpha_mlm = hparams.alpha_mlm
         self.alpha_ce = hparams.alpha_ce
         self.alpha_hid = hparams.alpha_hid
-        self.alpha_encoder_loss = hparams.alpha_encoder_loss
         gc.collect()
         torch.cuda.empty_cache()
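get_layers_to_supervise is defined elsewhere in the examples and is not part of this diff; its role is to choose which teacher layers each supervised student layer should match when the depths differ. A plausible simplified version of that idea (an assumption about its behavior, not the repository's code) spreads the student layers evenly over the teacher and keeps the last teacher layer:

def pick_teacher_layers(n_student: int, n_teacher: int) -> list:
    """Illustrative stand-in for a helper like get_layers_to_supervise: map each of
    n_student supervised layers to a teacher layer index, spacing the choices as
    evenly as possible so the last student layer matches the last teacher layer."""
    assert 0 < n_student <= n_teacher
    step = n_teacher / n_student
    return [round(step * (i + 1)) - 1 for i in range(n_student)]

print(pick_teacher_layers(3, 12))  # [3, 7, 11]
print(pick_teacher_layers(1, 6))   # [5]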
@@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule):
             output_hidden_states=True,
             output_attentions=False,
             use_cache=False,
-        )
+        )  # TODO(@sshleifer): return_dict=True cleanup

         # Same cross entropy vs. label smoothing logic as finetune.py
         assert lm_logits.shape[-1] == self.model.config.vocab_size
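The TODO notes that the student call above still unpacks a positional tuple, while the rewritten teacher call further down already uses return_dict=True and reads fields by name. A small sketch of the difference (downloads a small checkpoint; the model name is only an example):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.eval()
batch = tok("studies have shown that owning a dog is good for you", return_tensors="pt")

with torch.no_grad():
    # return_dict=False: a positional tuple, fragile when output flags change.
    lm_logits = model(**batch, decoder_input_ids=batch.input_ids, return_dict=False)[0]
    # return_dict=True: a ModelOutput with named fields, as the teacher call below uses.
    outputs = model(**batch, decoder_input_ids=batch.input_ids, return_dict=True)

assert torch.equal(lm_logits, outputs.logits)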
@@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule):
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)

-        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
-        if self.different_encoder:
+        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
+        if self.different_encoder:  # compute encoder hidden state loss
             with torch.no_grad():
-                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
-                    input_ids, attention_mask=src_mask, output_hidden_states=True
-                )
-            # DEPRECATE THIS
-            if self.hparams.alpha_encoder_loss > 0:
-                loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
-            hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
-        teacher_enc_outputs = (enc_outputs,)
-        assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
+                teacher_enc_hid = self.teacher.get_encoder()(
+                    input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
+                ).hidden_states
+            hid_loss_enc = self.calc_hidden_loss(
+                src_mask,
+                enc_hidden_state,
+                teacher_enc_hid,
+                self.e_matches,
+                normalize_hidden=self.hparams.normalize_hidden,
+            )

         with torch.no_grad():
-            tloss, tlogits, tdec_hidden, _ = self.teacher(
+            outputs = self.teacher(
                 input_ids,
                 attention_mask=src_mask,
-                encoder_outputs=teacher_enc_outputs,
+                encoder_outputs=(enc_outputs,),
                 decoder_input_ids=decoder_input_ids,
                 lm_labels=labels,
                 output_hidden_states=True,
+                return_dict=True,
             )
+            tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
         dec_mask = decoder_input_ids.ne(pad_token_id)
         loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
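calc_hidden_loss now receives the e_matches indices and a normalize_hidden flag; its body is not part of this diff. As a rough sketch of the kind of masked hidden-state matching it performs (an assumption, not the repository's exact code), one can average an MSE over the non-padded positions of the selected layers:

import torch
import torch.nn.functional as F

def masked_hidden_mse(attention_mask, student_hidden, teacher_hidden, matches, normalize_hidden=False):
    """Illustrative sketch of intermediate hidden-state supervision.
    student_hidden / teacher_hidden: tuples of (batch, seq_len, dim) tensors, one per
    layer, as produced with output_hidden_states=True. matches[i] is the teacher
    layer that student layer i should imitate."""
    mask = attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)
    valid = mask.sum()
    loss = torch.tensor(0.0)
    for student_layer, teacher_layer in enumerate(matches):
        s, t = student_hidden[student_layer], teacher_hidden[teacher_layer]
        if normalize_hidden:  # optionally compare layer-normalized states
            s = F.layer_norm(s, s.shape[-1:])
            t = F.layer_norm(t, t.shape[-1:])
        # mask out padded positions before averaging
        loss = loss + (F.mse_loss(s, t, reduction="none") * mask).sum() / (valid * s.size(-1))
    return loss / len(matches)

# Tiny smoke test with random tensors standing in for real hidden states.
torch.manual_seed(0)
bsz, seq, dim = 2, 5, 8
student = tuple(torch.randn(bsz, seq, dim) for _ in range(3))
teacher = tuple(torch.randn(bsz, seq, dim) for _ in range(7))
attn = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(masked_hidden_mse(attn, student, teacher, matches=[2, 6], normalize_hidden=True))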
@@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule):
         blended_loss = (
             self.alpha_ce * loss_ce
             + self.alpha_mlm * student_lm_loss
-            + self.hparams.alpha_encoder_loss * loss_encoder
             + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
         )
-        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
+        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

     @staticmethod
     def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
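The return tuple now lines up with the shortened loss_names above, and the enc_mse term disappears from the blend. As a worked example of the weighting, using the parser defaults below (alpha_ce=0.8, alpha_mlm=0.2, alpha_hid=0.0) and made-up loss values:

# Plain arithmetic illustration of the blended loss; the loss values are invented.
alpha_ce, alpha_mlm, alpha_hid = 0.8, 0.2, 0.0
loss_ce, student_lm_loss = 1.10, 0.90
hid_loss_enc, hid_loss_dec = 0.20, 0.10

blended = alpha_ce * loss_ce + alpha_mlm * student_lm_loss + alpha_hid * (hid_loss_enc + hid_loss_dec)
print(blended)  # 0.8*1.10 + 0.2*0.90 + 0.0*0.30 = 1.06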
@@ -207,7 +218,6 @@ def add_distill_args(parser):
     parser.add_argument("--teacher", type=str)
     parser.add_argument("--alpha_ce", default=0.8, type=float)
     parser.add_argument("--alpha_mlm", default=0.2, type=float)
-    parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
     parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
     parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
     parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
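With --alpha_encoder_loss gone, a distillation run only needs the remaining weights. A minimal sketch of parsing the flags shown above in isolation (re-declared locally so it runs standalone; the real add_distill_args is wired into the full SummarizationModule parser, and the argument values here are placeholders):

import argparse

# Standalone re-declaration of the distillation flags visible in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)

# Example: a T5 teacher with hidden-state supervision enabled.
args = parser.parse_args(["--teacher", "t5-base", "--alpha_hid", "3.0", "--student_decoder_layers", "6"])
print(args.teacher, args.alpha_hid, args.student_decoder_layers)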
examples/seq2seq/test_seq2seq_examples.py
@@ -86,7 +86,6 @@ CHEAP_ARGS = {
     "n_val": -1,
     "n_test": -1,
     "student_encoder_layers": 1,
-    "alpha_encoder_loss": 0.0,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
 }
@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

-    @unittest.skip("T5 distillation is broken at the moment")
     def test_distill_t5(self):
         updates = dict(
             student_encoder_layers=1,
@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
             model_name_or_path="sshleifer/tinier_bart",
             teacher=CHEAP_ARGS["model_name_or_path"],
             val_check_interval=0.5,
-            alpha_encoder_loss=0.4,
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
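Both removals in this file work because the distillation tests build their hparams by copying CHEAP_ARGS and layering test-specific updates on top, so an obsolete key only has to be deleted from the two dictionaries. A minimal sketch of that pattern (keys abbreviated and values used as placeholders; the real CHEAP_ARGS is much larger):

import argparse

# Abbreviated stand-in for the test module's CHEAP_ARGS.
CHEAP_ARGS = {"model_name_or_path": "sshleifer/bart-tiny-random", "student_encoder_layers": 1, "freeze_encoder": False}

default_updates = dict(
    model_name_or_path="sshleifer/tinier_bart",
    teacher=CHEAP_ARGS["model_name_or_path"],
    val_check_interval=0.5,
)
default_updates.update(dict(student_encoder_layers=1))  # per-test overrides win

args_d: dict = CHEAP_ARGS.copy()
args_d.update(default_updates)
args = argparse.Namespace(**args_d)  # hparams-style attribute access, e.g. args.teacher
print(args.teacher, args.student_encoder_layers)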