chenpangpang / transformers · Commit 6078b120 (Unverified)

Authored Sep 04, 2020 by Sam Shleifer; committed by GitHub on Sep 04, 2020

[s2s] distill: --normalize_hidden --supervise_forward (#6834)

Parent: c5d43a87
Showing 3 changed files with 45 additions and 18 deletions.
examples/seq2seq/distil_marian_enro_teacher.sh  +3 -4
examples/seq2seq/distillation.py  +40 -14
examples/seq2seq/test_seq2seq_examples.py  +2 -0
examples/seq2seq/distil_marian_enro_teacher.sh
```diff
@@ -5,10 +5,9 @@ export WANDB_PROJECT=dmar
 python distillation.py \
   --learning_rate=3e-4 \
   --do_train \
   --do_predict \
   --fp16 \
   --val_check_interval 0.25 \
-  --teacher Helsinki-NLP/opus-mt-en-ro
-  --data_dir $ENRO_DIR \
+  --teacher Helsinki-NLP/opus-mt-en-ro \
   --max_source_length $MAX_LEN --max_target_length $MAX_LEN --val_max_target_length $MAX_LEN --test_max_target_length $MAX_LEN \
   --student_decoder_layers 3 --student_encoder_layers 6 \
   --freeze_encoder --freeze_embeds \
@@ -16,6 +15,6 @@ python distillation.py \
   --alpha_hid=3. \
   --train_batch_size=$BS --eval_batch_size=$BS \
   --tokenizer_name Helsinki-NLP/opus-mt-en-ro \
-  --warmup_steps 500 --sortish_sampler --logger_name wandb \
-  --gpus 1 --fp16_opt_level O1 --task translation \
+  --warmup_steps 500 --logger_name wandb \
+  --fp16_opt_level O1 --task translation --normalize_hidden \
   "$@"
```
examples/seq2seq/distillation.py
```diff
@@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule):
         }
         if hparams.length_penalty != -1:
             student_updates["length_penalty"] = hparams.length_penalty
-        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
-        e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
-        hparams.d_layer_to_copy = d_layers_to_copy
-        hparams.e_layer_to_copy = e_layers_to_copy
+        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        if hparams.supervise_forward:
+            hparams.d_matches = get_layers_to_supervise(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        else:
+            hparams.d_matches = d_layers_to_copy
+        hparams.d_layer_to_copy = d_layers_to_copy
         kw = teacher.config.to_diff_dict()
         kw.update(student_updates)
         # Copy weights
```
```diff
@@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule):
         dec_mask = decoder_input_ids.ne(pad_token_id)
         loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:
-            hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
+            hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
         blended_loss = (
             self.alpha_ce * loss_ce
```
```diff
@@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule):
         assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
         mask = attention_mask.to(hidden_states[0])
         valid_count = mask.sum() * hidden_states[0].size(-1)
-        hidden_losses = [
-            (F.mse_loss(hidden_states[i], hidden_states_T[j], reduction="none") * mask.unsqueeze(-1)).sum()
-            / valid_count
-            for i, j in enumerate(matches)
-        ]
-        return sum(hidden_losses)
+        student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
+        teacher_states = torch.stack([hidden_states_T[j] for j in matches])
+        if self.hparams.normalize_hidden:
+            student_states = F.layer_norm(student_states, student_states.shape[1:])
+            teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
+        mse = F.mse_loss(student_states, teacher_states, reduction="none")
+        masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
+        return masked_mse
 
 
 def add_distill_args(parser):
```
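The rewritten calc_hidden_loss above stacks the student hidden states against the matched teacher hidden states, optionally layer-norms both stacks (the new --normalize_hidden path), and reduces with a masked MSE. Below is a minimal self-contained sketch of that computation on dummy tensors; the helper name, tensor shapes, and the example match list are illustrative, not part of the module.

```python
import torch
import torch.nn.functional as F


def hidden_loss_sketch(hidden_states, hidden_states_T, attention_mask, matches, normalize_hidden):
    # hidden_states / hidden_states_T: per-layer tuples of (batch, seq_len, hidden) tensors.
    mask = attention_mask.to(hidden_states[0])
    valid_count = mask.sum() * hidden_states[0].size(-1)
    # Stack student layers and the matched teacher layers along a new leading "layer" dim.
    student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
    teacher_states = torch.stack([hidden_states_T[j] for j in matches])
    if normalize_hidden:
        # LayerNorm over everything but the layer dim, mirroring the --normalize_hidden branch.
        student_states = F.layer_norm(student_states, student_states.shape[1:])
        teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
    mse = F.mse_loss(student_states, teacher_states, reduction="none")
    # Zero out padded positions, then average over the count of valid token-feature entries.
    return (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count


# Dummy data: a 3-layer student supervised by teacher decoder layers [1, 4, 5].
batch, seq_len, hidden = 2, 5, 8
student_hidden = tuple(torch.randn(batch, seq_len, hidden) for _ in range(4))  # embeddings + 3 layers
teacher_hidden = tuple(torch.randn(batch, seq_len, hidden) for _ in range(7))  # embeddings + 6 layers
attention_mask = torch.ones(batch, seq_len)
print(hidden_loss_sketch(student_hidden, teacher_hidden, attention_mask, [1, 4, 5], normalize_hidden=True))
```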
```diff
@@ -255,6 +266,8 @@ def add_distill_args(parser):
     parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
     parser.add_argument("--no_teacher", action="store_true", default=False)
     parser.add_argument("--length_penalty", type=float, default=-1)
+    parser.add_argument("--supervise_forward", action="store_true", default=False)
+    parser.add_argument("--normalize_hidden", action="store_true", default=False)
 
 
 class BartTranslationDistiller(BartSummarizationDistiller):
```
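Both new switches are plain argparse store_true flags that default to False, so existing launch scripts keep their old behavior unless a flag is passed explicitly. A tiny standalone check (the parser below is a throwaway, not the project's real argument parser):

```python
import argparse

# Reproduces just the two new switches from add_distill_args.
parser = argparse.ArgumentParser()
parser.add_argument("--supervise_forward", action="store_true", default=False)
parser.add_argument("--normalize_hidden", action="store_true", default=False)

args = parser.parse_args(["--normalize_hidden"])
print(args.supervise_forward, args.normalize_hidden)  # False True
```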
```diff
@@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
         loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
         if self.alpha_hid > 0:
-            hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_layer_to_copy)
+            hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
         blended_loss = (
             self.alpha_ce * loss_ce
```
```diff
@@ -463,12 +476,25 @@ LAYERS_TO_COPY = {
     },
     6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
 }
+LAYERS_TO_SUPERVISE = {
+    12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
+    16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
+    6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
+    2: {1: [1], 2: [0, 1]},
+}
+
+
+def get_layers_to_supervise(n_student, n_teacher):
+    return LAYERS_TO_SUPERVISE[n_teacher][n_student]
 
 
 def get_layers_to_copy(n_student, n_teacher):
     try:
-        return LAYERS_TO_COPY[n_teacher][n_student]
+        val = LAYERS_TO_COPY[n_teacher][n_student]
+        assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student
+        return val
     except KeyError:
         if n_student != n_teacher:
             warnings.warn(f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}")
```
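get_layers_to_copy now also asserts that, for every hard-coded pairing it finds, the copy list and the supervise list each have exactly n_student entries. A quick standalone check of that invariant, using only the 6-layer-teacher entries visible in this diff (the variable names below are local to the snippet):

```python
# 6-layer-teacher entries as they appear in LAYERS_TO_COPY and LAYERS_TO_SUPERVISE above.
LAYERS_TO_COPY_6 = {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))}
LAYERS_TO_SUPERVISE_6 = {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]}

for n_student, supervised in LAYERS_TO_SUPERVISE_6.items():
    copied = LAYERS_TO_COPY_6[n_student]
    # Same invariant as the new assert in get_layers_to_copy.
    assert len(supervised) == len(copied) == n_student
    print(f"student={n_student}: copy teacher layers {copied}, supervise teacher layers {supervised}")
```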
examples/seq2seq/test_seq2seq_examples.py
```diff
@@ -31,6 +31,8 @@ logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger()
 CUDA_AVAILABLE = torch.cuda.is_available()
 CHEAP_ARGS = {
+    "supervise_forward": True,
+    "normalize_hidden": True,
     "label_smoothing": 0.2,
     "eval_beams": 1,
     "val_metric": "loss",
```