Repository: chenpangpang/transformers

Unverified commit d5d2744a, authored Oct 05, 2020 by Sam Shleifer, committed via GitHub on Oct 05, 2020 (parent: 818c294f)

Support T5 Distillation w/hidden state supervision (#7599)
Showing 2 changed files with 37 additions and 30 deletions:

  examples/seq2seq/distillation.py            +37  -27
  examples/seq2seq/test_seq2seq_examples.py    +0   -3
examples/seq2seq/distillation.py

@@ -28,7 +28,7 @@ from lightning_base import generic_train  # noqa
 class BartSummarizationDistiller(SummarizationModule):
     """Supports Bart, Pegasus and other models that inherit from Bart."""

-    loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
+    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]

     def __init__(self, hparams):
         assert Path(hparams.data_dir).exists()
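The loss_names list mirrors, in order, the tuple that _step returns later in this diff, so dropping "enc_mse_loss" here goes hand in hand with dropping loss_encoder from that return statement. A small standalone sketch of how the two line up for logging, with dummy tensors standing in for the real losses:

    import torch

    # Order taken from this diff: _step returns
    # (blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec).
    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
    loss_tensors = tuple(torch.tensor(0.0) for _ in loss_names)  # stand-in for self._step(batch)
    logs = dict(zip(loss_names, loss_tensors))
    print(logs["hid_loss_enc"])  # the encoder hidden-state supervision term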
@@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule):
         if hparams.length_penalty != -1:
             student.config.length_penalty = hparams.length_penalty
         super().__init__(hparams, model=student, config=student.config)
+        model_type = student.config.model_type
         self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]
-        self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
-        self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
+
+        if model_type == "t5":
+            teacher_encoder_layers = len(teacher.get_encoder().block)
+            teacher_decoder_layers = len(teacher.get_decoder().block)
+        else:
+            teacher_encoder_layers = teacher.config.encoder_layers
+            teacher_decoder_layers = teacher.config.decoder_layers
+
+        self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
+        self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
+
         self.teacher = teacher
         freeze_params(self.teacher)
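The model_type == "t5" branch exists because T5 and BART report their depth differently: BART-style configs expose encoder_layers and decoder_layers directly, while a T5 stack keeps its layers in a ModuleList called block. A self-contained sketch of the same check outside the trainer; the checkpoint name is only an example:

    from transformers import AutoModelForSeq2SeqLM

    def teacher_layer_counts(model):
        """Return (n_encoder_layers, n_decoder_layers) for T5- or BART-style seq2seq models."""
        if model.config.model_type == "t5":
            # T5Stack keeps its layers in `.block`
            return len(model.get_encoder().block), len(model.get_decoder().block)
        # BART/Pegasus-style configs expose the counts directly
        return model.config.encoder_layers, model.config.decoder_layers

    teacher = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # example checkpoint
    print(teacher_layer_counts(teacher))  # t5-small: 6 encoder layers, 6 decoder layers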
@@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule):
             del self.teacher.encoder

         # Intermediate supervision: Decide which layers to supervise
         if hparams.supervise_forward:
-            self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers)
-        else:
+            self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
+            self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
+        else:  # student layer should emulate hidden states of the teacher layer it was copied from
+            self.e_matches = self.e_layer_ids
             self.d_matches = self.d_layer_ids
         self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
         self.temperature = 2.0
         self.alpha_mlm = hparams.alpha_mlm
         self.alpha_ce = hparams.alpha_ce
         self.alpha_hid = hparams.alpha_hid
-        self.alpha_encoder_loss = hparams.alpha_encoder_loss
         gc.collect()
         torch.cuda.empty_cache()
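Two matching strategies are set up here: with supervise_forward, get_layers_to_supervise (a helper in the examples' utils.py, not shown in this diff) picks which teacher layers each student layer should imitate; otherwise every student layer is supervised by the teacher layer it was originally copied from (e_layer_ids / d_layer_ids). A rough, evenly-spaced illustration of the first idea; this is not the actual helper, whose mapping may differ:

    def pick_layers_to_supervise(n_student: int, n_teacher: int) -> list:
        """Illustrative only: spread n_student supervision targets over a deeper teacher,
        always ending on the teacher's final layer."""
        step = n_teacher // n_student
        return [step * (i + 1) - 1 for i in range(n_student)]

    print(pick_layers_to_supervise(3, 12))  # [3, 7, 11]
    print(pick_layers_to_supervise(1, 6))   # [5] -> only the last teacher layer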
@@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule):
             output_hidden_states=True,
             output_attentions=False,
             use_cache=False,
-        )
+        )  # TODO(@sshleifer): return_dict=True cleanup
         # Same cross entropy vs. label smoothing logic as finetune.py
         assert lm_logits.shape[-1] == self.model.config.vocab_size
@@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule):
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)

-        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
-        if self.different_encoder:
+        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
+        if self.different_encoder:  # compute encoder hidden state loss
             with torch.no_grad():
-                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
-                    input_ids, attention_mask=src_mask, output_hidden_states=True
-                )
-            if self.hparams.alpha_encoder_loss > 0:
-                loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
-
-            hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
-
-        teacher_enc_outputs = (enc_outputs,)
-        assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
+                teacher_enc_hid = self.teacher.get_encoder()(
+                    input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
+                ).hidden_states  # DEPRECATE THIS
+            hid_loss_enc = self.calc_hidden_loss(
+                src_mask,
+                enc_hidden_state,
+                teacher_enc_hid,
+                self.e_matches,
+                normalize_hidden=self.hparams.normalize_hidden,
+            )

         with torch.no_grad():
-            tloss, tlogits, tdec_hidden, _ = self.teacher(
+            outputs = self.teacher(
                 input_ids,
                 attention_mask=src_mask,
-                encoder_outputs=teacher_enc_outputs,
+                encoder_outputs=(enc_outputs,),
                 decoder_input_ids=decoder_input_ids,
                 lm_labels=labels,
                 output_hidden_states=True,
+                return_dict=True,
             )
+            tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
         dec_mask = decoder_input_ids.ne(pad_token_id)
         loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
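return_dict=True makes the teacher return a ModelOutput whose fields can be read by name, which is why the positional unpacking tloss, tlogits, tdec_hidden, _ = ... becomes outputs.logits and outputs.decoder_hidden_states. A tiny standalone illustration; the checkpoint is only an example:

    import torch
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("facebook/bart-base")   # example checkpoint
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
    batch = tok(["hello world"], return_tensors="pt")

    with torch.no_grad():
        out = model(
            **batch,
            decoder_input_ids=batch["input_ids"],
            output_hidden_states=True,
            return_dict=True,
        )

    print(out.logits.shape)                # named access instead of tuple indexing
    print(len(out.decoder_hidden_states))  # embedding output plus one entry per decoder layer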
@@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule):
         blended_loss = (
             self.alpha_ce * loss_ce
             + self.alpha_mlm * student_lm_loss
-            + self.hparams.alpha_encoder_loss * loss_encoder
             + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
         )
-        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
+        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

     @staticmethod
     def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
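calc_hidden_loss itself is unchanged by this commit and only appears as context, but its signature shows what the new hid_loss_enc call above feeds it: student hidden states, teacher hidden states, a list of matched teacher layer indices, and a normalize_hidden switch. A self-contained sketch of that kind of masked, optionally normalized MSE; the file's own implementation may differ in details:

    import torch
    import torch.nn.functional as F

    def masked_hidden_mse(attention_mask, student_hid, teacher_hid, matches, normalize_hidden=False):
        """Illustrative masked MSE between student layers and their matched teacher layers."""
        mask = attention_mask.unsqueeze(-1).float()            # (batch, seq_len, 1)
        valid = mask.sum() * student_hid[0].size(-1)           # number of unmasked scalar positions
        loss = torch.tensor(0.0)
        for student_layer, teacher_layer in enumerate(matches):
            s, t = student_hid[student_layer], teacher_hid[teacher_layer]
            if normalize_hidden:
                s = F.layer_norm(s, s.shape[-1:])
                t = F.layer_norm(t, t.shape[-1:])
            loss = loss + (F.mse_loss(s, t, reduction="none") * mask).sum() / valid
        return loss

    # Toy shapes: 2 student layers supervised by teacher layers 2 and 5.
    student_hid = [torch.randn(1, 4, 8) for _ in range(2)]
    teacher_hid = [torch.randn(1, 4, 8) for _ in range(6)]
    print(masked_hidden_mse(torch.ones(1, 4), student_hid, teacher_hid, matches=[2, 5]))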
@@ -207,7 +218,6 @@ def add_distill_args(parser):
...
@@ -207,7 +218,6 @@ def add_distill_args(parser):
parser
.
add_argument
(
"--teacher"
,
type
=
str
)
parser
.
add_argument
(
"--teacher"
,
type
=
str
)
parser
.
add_argument
(
"--alpha_ce"
,
default
=
0.8
,
type
=
float
)
parser
.
add_argument
(
"--alpha_ce"
,
default
=
0.8
,
type
=
float
)
parser
.
add_argument
(
"--alpha_mlm"
,
default
=
0.2
,
type
=
float
)
parser
.
add_argument
(
"--alpha_mlm"
,
default
=
0.2
,
type
=
float
)
parser
.
add_argument
(
"--alpha_encoder_loss"
,
default
=
0.0
,
type
=
float
)
parser
.
add_argument
(
"--alpha_hid"
,
default
=
0.0
,
type
=
float
,
required
=
False
)
parser
.
add_argument
(
"--alpha_hid"
,
default
=
0.0
,
type
=
float
,
required
=
False
)
parser
.
add_argument
(
"--student_decoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--student_decoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--student_encoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--student_encoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
...
...
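With --alpha_encoder_loss gone, hidden-state supervision is weighted entirely by --alpha_hid, which multiplies hid_loss_enc + hid_loss_dec in the blended loss above (supervise_forward and normalize_hidden are read from hparams in __init__ and _step; their flags are not part of this diff). A runnable sketch of the remaining distillation flags, using only arguments shown in this hunk and hypothetical values:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher", type=str)
    parser.add_argument("--alpha_ce", default=0.8, type=float)
    parser.add_argument("--alpha_mlm", default=0.2, type=float)
    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
    parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
    parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)

    # Hypothetical T5 distillation settings: a 6/6-layer student with hidden-state supervision on.
    args = parser.parse_args(
        ["--teacher", "t5-base", "--alpha_hid", "3.0", "--student_encoder_layers", "6", "--student_decoder_layers", "6"]
    )
    print(args.teacher, args.alpha_hid, args.student_encoder_layers, args.student_decoder_layers)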
examples/seq2seq/test_seq2seq_examples.py

@@ -86,7 +86,6 @@ CHEAP_ARGS = {
     "n_val": -1,
     "n_test": -1,
     "student_encoder_layers": 1,
-    "alpha_encoder_loss": 0.0,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
 }

@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

-    @unittest.skip("T5 distillation is broken at the moment")
     def test_distill_t5(self):
         updates = dict(
             student_encoder_layers=1,

@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
             model_name_or_path="sshleifer/tinier_bart",
             teacher=CHEAP_ARGS["model_name_or_path"],
             val_check_interval=0.5,
-            alpha_encoder_loss=0.4,
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()