Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6078b120
Unverified
Commit
6078b120
authored
Sep 04, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 04, 2020
Browse files
[s2s] distill: --normalize_hidden --supervise_forward (#6834)
parent
c5d43a87
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
18 deletions
+45
-18
examples/seq2seq/distil_marian_enro_teacher.sh
examples/seq2seq/distil_marian_enro_teacher.sh
+3
-4
examples/seq2seq/distillation.py
examples/seq2seq/distillation.py
+40
-14
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+2
-0
No files found.
examples/seq2seq/distil_marian_enro_teacher.sh
View file @
6078b120
...
...
@@ -5,10 +5,9 @@ export WANDB_PROJECT=dmar
python distillation.py
\
--learning_rate
=
3e-4
\
--do_train
\
--do_predict
\
--fp16
\
--val_check_interval
0.25
\
--teacher
Helsinki-NLP/opus-mt-en-ro
--data_dir
$ENRO_DIR
\
--teacher
Helsinki-NLP/opus-mt-en-ro
\
--max_source_length
$MAX_LEN
--max_target_length
$MAX_LEN
--val_max_target_length
$MAX_LEN
--test_max_target_length
$MAX_LEN
\
--student_decoder_layers
3
--student_encoder_layers
6
\
--freeze_encoder
--freeze_embeds
\
...
...
@@ -16,6 +15,6 @@ python distillation.py \
--alpha_hid
=
3.
\
--train_batch_size
=
$BS
--eval_batch_size
=
$BS
\
--tokenizer_name
Helsinki-NLP/opus-mt-en-ro
\
--warmup_steps
500
--sortish_sampler
--logger_name
wandb
\
--gpus
1
--fp16_opt_level
O1
--task
translation
\
--warmup_steps
500
--logger_name
wandb
\
--fp16_opt_level
O1
--task
translation
--normalize_hidden
\
"
$@
"
examples/seq2seq/distillation.py
View file @
6078b120
...
...
@@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule):
}
if
hparams
.
length_penalty
!=
-
1
:
student_updates
[
"length_penalty"
]
=
hparams
.
length_penalty
d_layers_to_copy
:
List
=
get_layers_to_copy
(
student_updates
[
"decoder_layers"
],
teacher
.
config
.
decoder_layers
)
e_layers_to_copy
:
List
=
get_layers_to_copy
(
student_updates
[
"encoder_layers"
],
teacher
.
config
.
encoder_layers
)
hparams
.
d_layer_to_copy
=
d_layers_to_copy
hparams
.
e_layer_to_copy
=
e_layers_to_copy
d_layers_to_copy
:
List
=
get_layers_to_copy
(
student_updates
[
"decoder_layers"
],
teacher
.
config
.
decoder_layers
)
if
hparams
.
supervise_forward
:
hparams
.
d_matches
=
get_layers_to_supervise
(
student_updates
[
"decoder_layers"
],
teacher
.
config
.
decoder_layers
)
else
:
hparams
.
d_matches
=
d_layers_to_copy
hparams
.
d_layer_to_copy
=
d_layers_to_copy
kw
=
teacher
.
config
.
to_diff_dict
()
kw
.
update
(
student_updates
)
# Copy weights
...
...
@@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule):
dec_mask
=
decoder_input_ids
.
ne
(
pad_token_id
)
loss_ce
,
s_logits_slct
,
t_logits_slct
=
self
.
calc_ce_loss
(
dec_mask
,
lm_logits
,
tlogits
)
if
self
.
alpha_hid
>
0
:
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
layer_to_copy
)
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
matches
)
blended_loss
=
(
self
.
alpha_ce
*
loss_ce
...
...
@@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule):
assert
not
isinstance
(
hidden_states_T
,
torch
.
Tensor
),
f
"
{
msg
}{
hidden_states_T
.
shape
}
"
mask
=
attention_mask
.
to
(
hidden_states
[
0
])
valid_count
=
mask
.
sum
()
*
hidden_states
[
0
].
size
(
-
1
)
hidden_losses
=
[
(
F
.
mse_loss
(
hidden_states
[
i
],
hidden_states_T
[
j
],
reduction
=
"none"
)
*
mask
.
unsqueeze
(
-
1
)).
sum
()
/
valid_count
for
i
,
j
in
enumerate
(
matches
)
]
return
sum
(
hidden_losses
)
student_states
=
torch
.
stack
([
hidden_states
[
i
]
for
i
in
range
(
len
(
matches
))])
teacher_states
=
torch
.
stack
([
hidden_states_T
[
j
]
for
j
in
matches
])
if
self
.
hparams
.
normalize_hidden
:
student_states
=
F
.
layer_norm
(
student_states
,
student_states
.
shape
[
1
:])
teacher_states
=
F
.
layer_norm
(
teacher_states
,
teacher_states
.
shape
[
1
:])
mse
=
F
.
mse_loss
(
student_states
,
teacher_states
,
reduction
=
"none"
)
masked_mse
=
(
mse
*
mask
.
unsqueeze
(
0
).
unsqueeze
(
-
1
)).
sum
()
/
valid_count
return
masked_mse
def
add_distill_args
(
parser
):
...
...
@@ -255,6 +266,8 @@ def add_distill_args(parser):
parser
.
add_argument
(
"--student_encoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--no_teacher"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--length_penalty"
,
type
=
float
,
default
=-
1
)
parser
.
add_argument
(
"--supervise_forward"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--normalize_hidden"
,
action
=
"store_true"
,
default
=
False
)
class
BartTranslationDistiller
(
BartSummarizationDistiller
):
...
...
@@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
loss_ce
,
s_logits_slct
,
t_logits_slct
=
self
.
calc_ce_loss
(
dec_mask
,
slogits
,
tlogits
)
if
self
.
alpha_hid
>
0
:
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
layer_to_copy
)
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
matches
)
blended_loss
=
(
self
.
alpha_ce
*
loss_ce
...
...
@@ -463,15 +476,28 @@ LAYERS_TO_COPY = {
},
6
:
{
1
:
[
0
],
2
:
[
0
,
5
],
3
:
[
0
,
2
,
5
],
4
:
[
0
,
1
,
3
,
5
],
6
:
list
(
range
(
6
))},
}
LAYERS_TO_SUPERVISE
=
{
12
:
{
1
:
[
11
],
2
:
[
5
,
11
],
3
:
[
3
,
7
,
11
],
6
:
[
1
,
3
,
5
,
8
,
10
,
11
]},
16
:
{
1
:
[
15
],
4
:
[
4
,
9
,
12
,
15
],
8
:
[
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
]},
6
:
{
1
:
[
5
],
2
:
[
3
,
5
],
3
:
[
1
,
4
,
5
],
4
:
[
1
,
2
,
4
,
5
]},
2
:
{
1
:
[
1
],
2
:
[
0
,
1
]},
}
def
get_layers_to_supervise
(
n_student
,
n_teacher
):
return
LAYERS_TO_SUPERVISE
[
n_teacher
][
n_student
]
def
get_layers_to_copy
(
n_student
,
n_teacher
):
try
:
return
LAYERS_TO_COPY
[
n_teacher
][
n_student
]
val
=
LAYERS_TO_COPY
[
n_teacher
][
n_student
]
assert
len
(
LAYERS_TO_SUPERVISE
[
n_teacher
][
n_student
])
==
len
(
val
)
==
n_student
return
val
except
KeyError
:
warnings
.
warn
(
f
"no hardcoded layers to copy for teacher
{
n_teacher
}
-> student
{
n_student
}
, defaulting to first
{
n_student
}
"
)
if
n_student
!=
n_teacher
:
warnings
.
warn
(
f
"no hardcoded layers to copy for teacher
{
n_teacher
}
-> student
{
n_student
}
, defaulting to first
{
n_student
}
"
)
return
list
(
range
(
n_student
))
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
6078b120
...
...
@@ -31,6 +31,8 @@ logging.basicConfig(level=logging.DEBUG)
logger
=
logging
.
getLogger
()
CUDA_AVAILABLE
=
torch
.
cuda
.
is_available
()
CHEAP_ARGS
=
{
"supervise_forward"
:
True
,
"normalize_hidden"
:
True
,
"label_smoothing"
:
0.2
,
"eval_beams"
:
1
,
"val_metric"
:
"loss"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment