chenpangpang / transformers · commit d5d2744a

Support T5 Distillation w/hidden state supervision (#7599)

Unverified commit d5d2744a, authored Oct 05, 2020 by Sam Shleifer, committed by GitHub on Oct 05, 2020.
Parent: 818c294f
Showing 2 changed files with 37 additions and 30 deletions:
    examples/seq2seq/distillation.py           +37 −27
    examples/seq2seq/test_seq2seq_examples.py   +0 −3
examples/seq2seq/distillation.py (view file @ d5d2744a)

@@ -28,7 +28,7 @@ from lightning_base import generic_train  # noqa
 class BartSummarizationDistiller(SummarizationModule):
     """Supports Bart, Pegasus and other models that inherit from Bart."""
 
-    loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
+    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
 
     def __init__(self, hparams):
         assert Path(hparams.data_dir).exists()
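Note: loss_names drops "enc_mse_loss" because _step no longer returns an encoder MSE term (see the return-statement change further down). A rough sketch of the kind of name-to-value pairing this list presumably feeds for logging; the pairing code itself is not part of this diff and the values are made up:

    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
    loss_tensors = (2.6, 2.5, 3.0, 0.4, 0.6)    # blended, CE, LM, enc-hidden, dec-hidden (illustrative)
    logs = dict(zip(loss_names, loss_tensors))  # each name must line up with the tuple _step returns
    print(logs["hid_loss_dec"])                 # 0.6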
@@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule):
         if hparams.length_penalty != -1:
             student.config.length_penalty = hparams.length_penalty
         super().__init__(hparams, model=student, config=student.config)
+        model_type = student.config.model_type
         self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]
-        self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
-        self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
+
+        if model_type == "t5":
+            teacher_encoder_layers = len(teacher.get_encoder().block)
+            teacher_decoder_layers = len(teacher.get_decoder().block)
+        else:
+            teacher_encoder_layers = teacher.config.encoder_layers
+            teacher_decoder_layers = teacher.config.decoder_layers
+
+        self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
+        self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
         self.teacher = teacher
         freeze_params(self.teacher)
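Note: T5 configs expose num_layers/num_decoder_layers rather than the BART-style encoder_layers/decoder_layers, which is why teacher depth is read off the module lists when model_type == "t5". A minimal sketch of that distinction, assuming a transformers install with T5 weights available:

    from transformers import T5ForConditionalGeneration

    teacher = T5ForConditionalGeneration.from_pretrained("t5-small")
    print(len(teacher.get_encoder().block))  # 6 encoder blocks in t5-small
    print(len(teacher.get_decoder().block))  # 6 decoder blocks
    print(teacher.config.num_layers)         # 6; T5 has no config.encoder_layers attribute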
@@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule):
                 del self.teacher.encoder
 
         # Intermediate supervision: Decide which layers to supervise
         if hparams.supervise_forward:
-            self.d_matches = get_layers_to_supervise(
-                n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
+            self.e_matches = get_layers_to_supervise(
+                n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
+            )
+            self.d_matches = get_layers_to_supervise(
+                n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
             )
-        else:  # student layer should emulate hidden states of the teacher layer it was copied from
+        else:
+            self.e_matches = self.e_layer_ids
             self.d_matches = self.d_layer_ids
 
         self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
         self.temperature = 2.0
         self.alpha_mlm = hparams.alpha_mlm
         self.alpha_ce = hparams.alpha_ce
         self.alpha_hid = hparams.alpha_hid
-        self.alpha_encoder_loss = hparams.alpha_encoder_loss
         gc.collect()
         torch.cuda.empty_cache()
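Note: with supervise_forward set, both encoder and decoder matches now come from get_layers_to_supervise, fed the teacher depths computed above. Roughly, that helper spreads n_student supervision points across n_teacher layers; the sketch below only illustrates that idea and is not the helper's actual implementation:

    def pick_layers_to_supervise(n_student: int, n_teacher: int) -> list:
        # evenly spaced teacher layer indices, ending at the last teacher layer
        step = n_teacher // n_student
        return [step * (i + 1) - 1 for i in range(n_student)]

    print(pick_layers_to_supervise(3, 12))  # [3, 7, 11]
    print(pick_layers_to_supervise(1, 6))   # [5]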
@@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule):
             output_hidden_states=True,
             output_attentions=False,
             use_cache=False,
-        )  # TODO(@sshleifer): return_dict=True cleanup
+        )
         # Same cross entropy vs. label smoothing logic as finetune.py
         assert lm_logits.shape[-1] == self.model.config.vocab_size
@@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule):
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)
 
-        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
-        if self.different_encoder:
+        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
+        if self.different_encoder:  # compute encoder hidden state loss
             with torch.no_grad():
-                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
-                    input_ids, attention_mask=src_mask, output_hidden_states=True
-                )  # DEPRECATE THIS
-            if self.hparams.alpha_encoder_loss > 0:
-                loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
-            hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
-        teacher_enc_outputs = (enc_outputs,)
-        assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
+                teacher_enc_hid = self.teacher.get_encoder()(
+                    input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
+                ).hidden_states
+            hid_loss_enc = self.calc_hidden_loss(
+                src_mask,
+                enc_hidden_state,
+                teacher_enc_hid,
+                self.e_matches,
+                normalize_hidden=self.hparams.normalize_hidden,
+            )
 
         with torch.no_grad():
-            tloss, tlogits, tdec_hidden, _ = self.teacher(
+            outputs = self.teacher(
                 input_ids,
                 attention_mask=src_mask,
-                encoder_outputs=teacher_enc_outputs,
+                encoder_outputs=(enc_outputs,),
                 decoder_input_ids=decoder_input_ids,
                 lm_labels=labels,
                 output_hidden_states=True,
+                return_dict=True,
             )
+        tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
         dec_mask = decoder_input_ids.ne(pad_token_id)
         loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
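Note: the encoder branch now relies entirely on calc_hidden_loss (the separate calc_mse_loss path and its alpha_encoder_loss weight are gone), and return_dict=True lets the teacher outputs be read by attribute. The sketch below shows the kind of masked, optionally layer-normalized MSE such a hidden-state loss computes; the names and exact reduction are assumptions, not the file's code:

    import torch
    import torch.nn.functional as F

    def masked_hidden_loss(attention_mask, student_hid, teacher_hid, matches, normalize_hidden=False):
        """matches[i] = index of the teacher layer that supervises student layer i."""
        mask = attention_mask.unsqueeze(-1).float()    # zero out padded positions
        denom = mask.sum() * student_hid[0].size(-1)   # number of unmasked scalar entries per layer pair
        loss = torch.zeros((), device=mask.device)
        for student_idx, teacher_idx in enumerate(matches):
            s, t = student_hid[student_idx], teacher_hid[teacher_idx]
            if normalize_hidden:                       # mirrors the normalize_hidden flag above
                s = F.layer_norm(s, s.shape[-1:])
                t = F.layer_norm(t, t.shape[-1:])
            loss = loss + (F.mse_loss(s, t, reduction="none") * mask).sum() / denom
        return loss / len(matches)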
@@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule):
         blended_loss = (
             self.alpha_ce * loss_ce
             + self.alpha_mlm * student_lm_loss
-            + self.hparams.alpha_encoder_loss * loss_encoder
             + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
         )
-        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
+        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
 
     @staticmethod
     def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
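Note: with loss_encoder removed, the blend reduces to three weighted terms. A toy check using the defaults added below (--alpha_ce 0.8, --alpha_mlm 0.2, --alpha_hid 0.0); the loss values themselves are made up:

    alpha_ce, alpha_mlm, alpha_hid = 0.8, 0.2, 0.0
    loss_ce, student_lm_loss = 2.5, 3.0
    hid_loss_enc, hid_loss_dec = 0.4, 0.6
    blended = alpha_ce * loss_ce + alpha_mlm * student_lm_loss + alpha_hid * (hid_loss_enc + hid_loss_dec)
    print(blended)  # 0.8 * 2.5 + 0.2 * 3.0 + 0.0 = 2.6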
@@ -207,7 +218,6 @@ def add_distill_args(parser):
     parser.add_argument("--teacher", type=str)
     parser.add_argument("--alpha_ce", default=0.8, type=float)
     parser.add_argument("--alpha_mlm", default=0.2, type=float)
-    parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
     parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
     parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
     parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
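Note: --alpha_encoder_loss is gone from add_distill_args, so passing it now fails argument parsing. A quick standalone sanity sketch that mirrors only the flags shown in this hunk (it does not import the real parser, which may register additional arguments):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher", type=str)
    parser.add_argument("--alpha_ce", default=0.8, type=float)
    parser.add_argument("--alpha_mlm", default=0.2, type=float)
    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
    parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
    parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)

    args = parser.parse_args(["--teacher", "t5-base", "--alpha_hid", "3.0", "--student_decoder_layers", "3"])
    print(args.alpha_hid, args.student_decoder_layers)    # 3.0 3
    # parser.parse_args(["--alpha_encoder_loss", "0.4"])  # would now exit with "unrecognized arguments"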
examples/seq2seq/test_seq2seq_examples.py (view file @ d5d2744a)

@@ -86,7 +86,6 @@ CHEAP_ARGS = {
     "n_val": -1,
     "n_test": -1,
     "student_encoder_layers": 1,
-    "alpha_encoder_loss": 0.0,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
 }

@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
 
-    @unittest.skip("T5 distillation is broken at the moment")
     def test_distill_t5(self):
         updates = dict(
             student_encoder_layers=1,

@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
             model_name_or_path="sshleifer/tinier_bart",
             teacher=CHEAP_ARGS["model_name_or_path"],
             val_check_interval=0.5,
-            alpha_encoder_loss=0.4,
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
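Note: removing the @unittest.skip line re-enables test_distill_t5, and the alpha_encoder_loss keyword disappears from the BART test's update dict. The T5 test body is truncated above; presumably it follows the same pattern as the BART case, roughly along these lines (the tiny checkpoint name and helper call are assumptions, not shown in this diff):

    updates = dict(
        student_encoder_layers=1,
        student_decoder_layers=1,
        alpha_hid=2.0,                              # exercise the hidden-state supervision path
        teacher="patrickvonplaten/t5-tiny-random",  # hypothetical tiny T5 teacher checkpoint
        model_name_or_path="patrickvonplaten/t5-tiny-random",
    )
    # self._test_distiller_cli(updates)  # helper name assumed from the surrounding test class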