Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6078b120
Unverified
Commit
6078b120
authored
Sep 04, 2020
by
Sam Shleifer
Committed by
GitHub
Sep 04, 2020
Browse files
[s2s] distill: --normalize_hidden --supervise_forward (#6834)
parent
c5d43a87
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
18 deletions
+45
-18
examples/seq2seq/distil_marian_enro_teacher.sh
examples/seq2seq/distil_marian_enro_teacher.sh
+3
-4
examples/seq2seq/distillation.py
examples/seq2seq/distillation.py
+40
-14
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+2
-0
No files found.
examples/seq2seq/distil_marian_enro_teacher.sh
View file @
6078b120
...
@@ -5,10 +5,9 @@ export WANDB_PROJECT=dmar
...
@@ -5,10 +5,9 @@ export WANDB_PROJECT=dmar
python distillation.py
\
python distillation.py
\
--learning_rate
=
3e-4
\
--learning_rate
=
3e-4
\
--do_train
\
--do_train
\
--do_predict
\
--fp16
\
--fp16
\
--val_check_interval
0.25
\
--val_check_interval
0.25
\
--teacher
Helsinki-NLP/opus-mt-en-ro
--data_dir
$ENRO_DIR
\
--teacher
Helsinki-NLP/opus-mt-en-ro
\
--max_source_length
$MAX_LEN
--max_target_length
$MAX_LEN
--val_max_target_length
$MAX_LEN
--test_max_target_length
$MAX_LEN
\
--max_source_length
$MAX_LEN
--max_target_length
$MAX_LEN
--val_max_target_length
$MAX_LEN
--test_max_target_length
$MAX_LEN
\
--student_decoder_layers
3
--student_encoder_layers
6
\
--student_decoder_layers
3
--student_encoder_layers
6
\
--freeze_encoder
--freeze_embeds
\
--freeze_encoder
--freeze_embeds
\
...
@@ -16,6 +15,6 @@ python distillation.py \
...
@@ -16,6 +15,6 @@ python distillation.py \
--alpha_hid
=
3.
\
--alpha_hid
=
3.
\
--train_batch_size
=
$BS
--eval_batch_size
=
$BS
\
--train_batch_size
=
$BS
--eval_batch_size
=
$BS
\
--tokenizer_name
Helsinki-NLP/opus-mt-en-ro
\
--tokenizer_name
Helsinki-NLP/opus-mt-en-ro
\
--warmup_steps
500
--sortish_sampler
--logger_name
wandb
\
--warmup_steps
500
--logger_name
wandb
\
--gpus
1
--fp16_opt_level
O1
--task
translation
\
--fp16_opt_level
O1
--task
translation
--normalize_hidden
\
"
$@
"
"
$@
"
examples/seq2seq/distillation.py
View file @
6078b120
...
@@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule):
...
@@ -87,10 +87,19 @@ class BartSummarizationDistiller(SummarizationModule):
}
}
if
hparams
.
length_penalty
!=
-
1
:
if
hparams
.
length_penalty
!=
-
1
:
student_updates
[
"length_penalty"
]
=
hparams
.
length_penalty
student_updates
[
"length_penalty"
]
=
hparams
.
length_penalty
d_layers_to_copy
:
List
=
get_layers_to_copy
(
student_updates
[
"decoder_layers"
],
teacher
.
config
.
decoder_layers
)
e_layers_to_copy
:
List
=
get_layers_to_copy
(
student_updates
[
"encoder_layers"
],
teacher
.
config
.
encoder_layers
)
e_layers_to_copy
:
List
=
get_layers_to_copy
(
student_updates
[
"encoder_layers"
],
teacher
.
config
.
encoder_layers
)
hparams
.
d_layer_to_copy
=
d_layers_to_copy
hparams
.
e_layer_to_copy
=
e_layers_to_copy
hparams
.
e_layer_to_copy
=
e_layers_to_copy
d_layers_to_copy
:
List
=
get_layers_to_copy
(
student_updates
[
"decoder_layers"
],
teacher
.
config
.
decoder_layers
)
if
hparams
.
supervise_forward
:
hparams
.
d_matches
=
get_layers_to_supervise
(
student_updates
[
"decoder_layers"
],
teacher
.
config
.
decoder_layers
)
else
:
hparams
.
d_matches
=
d_layers_to_copy
hparams
.
d_layer_to_copy
=
d_layers_to_copy
kw
=
teacher
.
config
.
to_diff_dict
()
kw
=
teacher
.
config
.
to_diff_dict
()
kw
.
update
(
student_updates
)
kw
.
update
(
student_updates
)
# Copy weights
# Copy weights
...
@@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule):
...
@@ -221,7 +230,7 @@ class BartSummarizationDistiller(SummarizationModule):
dec_mask
=
decoder_input_ids
.
ne
(
pad_token_id
)
dec_mask
=
decoder_input_ids
.
ne
(
pad_token_id
)
loss_ce
,
s_logits_slct
,
t_logits_slct
=
self
.
calc_ce_loss
(
dec_mask
,
lm_logits
,
tlogits
)
loss_ce
,
s_logits_slct
,
t_logits_slct
=
self
.
calc_ce_loss
(
dec_mask
,
lm_logits
,
tlogits
)
if
self
.
alpha_hid
>
0
:
if
self
.
alpha_hid
>
0
:
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
layer_to_copy
)
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
matches
)
blended_loss
=
(
blended_loss
=
(
self
.
alpha_ce
*
loss_ce
self
.
alpha_ce
*
loss_ce
...
@@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule):
...
@@ -237,12 +246,14 @@ class BartSummarizationDistiller(SummarizationModule):
assert
not
isinstance
(
hidden_states_T
,
torch
.
Tensor
),
f
"
{
msg
}{
hidden_states_T
.
shape
}
"
assert
not
isinstance
(
hidden_states_T
,
torch
.
Tensor
),
f
"
{
msg
}{
hidden_states_T
.
shape
}
"
mask
=
attention_mask
.
to
(
hidden_states
[
0
])
mask
=
attention_mask
.
to
(
hidden_states
[
0
])
valid_count
=
mask
.
sum
()
*
hidden_states
[
0
].
size
(
-
1
)
valid_count
=
mask
.
sum
()
*
hidden_states
[
0
].
size
(
-
1
)
hidden_losses
=
[
student_states
=
torch
.
stack
([
hidden_states
[
i
]
for
i
in
range
(
len
(
matches
))])
(
F
.
mse_loss
(
hidden_states
[
i
],
hidden_states_T
[
j
],
reduction
=
"none"
)
*
mask
.
unsqueeze
(
-
1
)).
sum
()
teacher_states
=
torch
.
stack
([
hidden_states_T
[
j
]
for
j
in
matches
])
/
valid_count
if
self
.
hparams
.
normalize_hidden
:
for
i
,
j
in
enumerate
(
matches
)
student_states
=
F
.
layer_norm
(
student_states
,
student_states
.
shape
[
1
:])
]
teacher_states
=
F
.
layer_norm
(
teacher_states
,
teacher_states
.
shape
[
1
:])
return
sum
(
hidden_losses
)
mse
=
F
.
mse_loss
(
student_states
,
teacher_states
,
reduction
=
"none"
)
masked_mse
=
(
mse
*
mask
.
unsqueeze
(
0
).
unsqueeze
(
-
1
)).
sum
()
/
valid_count
return
masked_mse
def
add_distill_args
(
parser
):
def
add_distill_args
(
parser
):
...
@@ -255,6 +266,8 @@ def add_distill_args(parser):
...
@@ -255,6 +266,8 @@ def add_distill_args(parser):
parser
.
add_argument
(
"--student_encoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--student_encoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--no_teacher"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--no_teacher"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--length_penalty"
,
type
=
float
,
default
=-
1
)
parser
.
add_argument
(
"--length_penalty"
,
type
=
float
,
default
=-
1
)
parser
.
add_argument
(
"--supervise_forward"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--normalize_hidden"
,
action
=
"store_true"
,
default
=
False
)
class
BartTranslationDistiller
(
BartSummarizationDistiller
):
class
BartTranslationDistiller
(
BartSummarizationDistiller
):
...
@@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
...
@@ -389,7 +402,7 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
loss_ce
,
s_logits_slct
,
t_logits_slct
=
self
.
calc_ce_loss
(
dec_mask
,
slogits
,
tlogits
)
loss_ce
,
s_logits_slct
,
t_logits_slct
=
self
.
calc_ce_loss
(
dec_mask
,
slogits
,
tlogits
)
if
self
.
alpha_hid
>
0
:
if
self
.
alpha_hid
>
0
:
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
layer_to_copy
)
hid_loss_dec
=
self
.
calc_hidden_loss
(
dec_mask
,
dec_hidden
,
tdec_hidden
,
self
.
hparams
.
d_
matches
)
blended_loss
=
(
blended_loss
=
(
self
.
alpha_ce
*
loss_ce
self
.
alpha_ce
*
loss_ce
...
@@ -463,15 +476,28 @@ LAYERS_TO_COPY = {
...
@@ -463,15 +476,28 @@ LAYERS_TO_COPY = {
},
},
6
:
{
1
:
[
0
],
2
:
[
0
,
5
],
3
:
[
0
,
2
,
5
],
4
:
[
0
,
1
,
3
,
5
],
6
:
list
(
range
(
6
))},
6
:
{
1
:
[
0
],
2
:
[
0
,
5
],
3
:
[
0
,
2
,
5
],
4
:
[
0
,
1
,
3
,
5
],
6
:
list
(
range
(
6
))},
}
}
LAYERS_TO_SUPERVISE
=
{
12
:
{
1
:
[
11
],
2
:
[
5
,
11
],
3
:
[
3
,
7
,
11
],
6
:
[
1
,
3
,
5
,
8
,
10
,
11
]},
16
:
{
1
:
[
15
],
4
:
[
4
,
9
,
12
,
15
],
8
:
[
1
,
3
,
5
,
7
,
9
,
11
,
13
,
15
]},
6
:
{
1
:
[
5
],
2
:
[
3
,
5
],
3
:
[
1
,
4
,
5
],
4
:
[
1
,
2
,
4
,
5
]},
2
:
{
1
:
[
1
],
2
:
[
0
,
1
]},
}
def
get_layers_to_supervise
(
n_student
,
n_teacher
):
return
LAYERS_TO_SUPERVISE
[
n_teacher
][
n_student
]
def
get_layers_to_copy
(
n_student
,
n_teacher
):
def
get_layers_to_copy
(
n_student
,
n_teacher
):
try
:
try
:
return
LAYERS_TO_COPY
[
n_teacher
][
n_student
]
val
=
LAYERS_TO_COPY
[
n_teacher
][
n_student
]
assert
len
(
LAYERS_TO_SUPERVISE
[
n_teacher
][
n_student
])
==
len
(
val
)
==
n_student
return
val
except
KeyError
:
except
KeyError
:
warnings
.
warn
(
if
n_student
!=
n_teacher
:
f
"no hardcoded layers to copy for teacher
{
n_teacher
}
-> student
{
n_student
}
, defaulting to first
{
n_student
}
"
warnings
.
warn
(
)
f
"no hardcoded layers to copy for teacher
{
n_teacher
}
-> student
{
n_student
}
, defaulting to first
{
n_student
}
"
)
return
list
(
range
(
n_student
))
return
list
(
range
(
n_student
))
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
6078b120
...
@@ -31,6 +31,8 @@ logging.basicConfig(level=logging.DEBUG)
...
@@ -31,6 +31,8 @@ logging.basicConfig(level=logging.DEBUG)
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
()
CUDA_AVAILABLE
=
torch
.
cuda
.
is_available
()
CUDA_AVAILABLE
=
torch
.
cuda
.
is_available
()
CHEAP_ARGS
=
{
CHEAP_ARGS
=
{
"supervise_forward"
:
True
,
"normalize_hidden"
:
True
,
"label_smoothing"
:
0.2
,
"label_smoothing"
:
0.2
,
"eval_beams"
:
1
,
"eval_beams"
:
1
,
"val_metric"
:
"loss"
,
"val_metric"
:
"loss"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment