Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f94a52cd
Unverified
Commit
f94a52cd
authored
Aug 12, 2020
by
Sam Shleifer
Committed by
GitHub
Aug 12, 2020
Browse files
[s2s] add BartTranslationDistiller for distilling mBART (#6363)
parent
d2370e1b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
98 additions
and
26 deletions
+98
-26
examples/seq2seq/distillation.py
examples/seq2seq/distillation.py
+72
-25
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+26
-1
No files found.
examples/seq2seq/distillation.py
View file @
f94a52cd
...
@@ -10,20 +10,40 @@ from torch import nn
...
@@ -10,20 +10,40 @@ from torch import nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
lightning_base
import
generic_train
from
lightning_base
import
generic_train
from
transformers
import
AdamW
,
BartConfig
,
BartForConditionalGeneration
,
T5Config
,
T5ForConditionalGeneration
from
transformers
import
(
AdamW
,
BartConfig
,
BartForConditionalGeneration
,
MBartTokenizer
,
T5Config
,
T5ForConditionalGeneration
,
)
try
:
try
:
from
.finetune
import
SummarizationModule
from
.finetune
import
SummarizationModule
,
TranslationModule
from
.finetune
import
main
as
ft_main
from
.initialization_utils
import
init_student
,
copy_layers
from
.initialization_utils
import
init_student
,
copy_layers
from
.utils
import
use_task_specific_params
,
pickle_load
,
freeze_params
,
assert_all_frozen
,
any_requires_grad
from
.utils
import
(
use_task_specific_params
,
pickle_load
,
freeze_params
,
assert_all_frozen
,
any_requires_grad
,
calculate_bleu_score
,
)
from
.finetune
import
main
as
ft_main
except
ImportError
:
except
ImportError
:
from
finetune
import
SummarizationModule
from
finetune
import
SummarizationModule
,
TranslationModule
from
finetune
import
main
as
ft_main
from
finetune
import
main
as
ft_main
from
initialization_utils
import
init_student
,
copy_layers
from
initialization_utils
import
init_student
,
copy_layers
from
utils
import
use_task_specific_params
,
pickle_load
,
freeze_params
,
assert_all_frozen
,
any_requires_grad
from
utils
import
(
use_task_specific_params
,
pickle_load
,
freeze_params
,
assert_all_frozen
,
any_requires_grad
,
calculate_bleu_score
,
)
class
BartSummarizationDistiller
(
SummarizationModule
):
class
BartSummarizationDistiller
(
SummarizationModule
):
...
@@ -159,17 +179,7 @@ class BartSummarizationDistiller(SummarizationModule):
...
@@ -159,17 +179,7 @@ class BartSummarizationDistiller(SummarizationModule):
@
staticmethod
@
staticmethod
def
add_model_specific_args
(
parser
,
root_dir
):
def
add_model_specific_args
(
parser
,
root_dir
):
SummarizationModule
.
add_model_specific_args
(
parser
,
root_dir
)
SummarizationModule
.
add_model_specific_args
(
parser
,
root_dir
)
parser
.
add_argument
(
"--teacher"
,
default
=
"facebook/bart-large-cnn"
,
type
=
str
)
add_distill_args
(
parser
)
parser
.
add_argument
(
"--alpha_ce"
,
default
=
0.8
,
type
=
float
)
parser
.
add_argument
(
"--alpha_mlm"
,
default
=
0.2
,
type
=
float
)
# parser.add_argument("--alpha_cos", default=0.0, type=float)
parser
.
add_argument
(
"--alpha_encoder_loss"
,
default
=
0.0
,
type
=
float
)
parser
.
add_argument
(
"--alpha_hid"
,
default
=
0.0
,
type
=
float
,
required
=
False
)
parser
.
add_argument
(
"--student_decoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--student_encoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--no_teacher"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--length_penalty"
,
type
=
float
,
default
=-
1
)
return
parser
return
parser
def
_step
(
self
,
batch
):
def
_step
(
self
,
batch
):
...
@@ -247,6 +257,44 @@ class BartSummarizationDistiller(SummarizationModule):
...
@@ -247,6 +257,44 @@ class BartSummarizationDistiller(SummarizationModule):
return
sum
(
hidden_losses
)
return
sum
(
hidden_losses
)
def
add_distill_args
(
parser
):
parser
.
add_argument
(
"--teacher"
,
default
=
"facebook/bart-large-cnn"
,
type
=
str
)
parser
.
add_argument
(
"--alpha_ce"
,
default
=
0.8
,
type
=
float
)
parser
.
add_argument
(
"--alpha_mlm"
,
default
=
0.2
,
type
=
float
)
parser
.
add_argument
(
"--alpha_encoder_loss"
,
default
=
0.0
,
type
=
float
)
parser
.
add_argument
(
"--alpha_hid"
,
default
=
0.0
,
type
=
float
,
required
=
False
)
parser
.
add_argument
(
"--student_decoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--student_encoder_layers"
,
default
=
12
,
type
=
int
,
required
=
False
)
parser
.
add_argument
(
"--no_teacher"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--length_penalty"
,
type
=
float
,
default
=-
1
)
class
BartTranslationDistiller
(
BartSummarizationDistiller
):
mode
=
"translation"
loss_names
=
[
"loss"
]
metric_names
=
[
"bleu"
]
val_metric
=
"bleu"
def
__init__
(
self
,
hparams
,
**
kwargs
):
super
().
__init__
(
hparams
,
**
kwargs
)
assert
isinstance
(
self
.
tokenizer
,
MBartTokenizer
)
assert
hparams
.
src_lang
is
not
None
assert
hparams
.
tgt_lang
is
not
None
self
.
dataset_kwargs
[
"src_lang"
]
=
hparams
.
src_lang
self
.
dataset_kwargs
[
"tgt_lang"
]
=
hparams
.
tgt_lang
if
self
.
model
.
config
.
decoder_start_token_id
is
None
and
isinstance
(
self
.
tokenizer
,
MBartTokenizer
):
self
.
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
hparams
.
tgt_lang
]
def
calc_generative_metrics
(
self
,
preds
,
target
)
->
dict
:
return
calculate_bleu_score
(
preds
,
target
)
@
staticmethod
def
add_model_specific_args
(
parser
,
root_dir
):
TranslationModule
.
add_model_specific_args
(
parser
,
root_dir
)
add_distill_args
(
parser
)
return
parser
class
T5SummarizationDistiller
(
BartSummarizationDistiller
):
class
T5SummarizationDistiller
(
BartSummarizationDistiller
):
def
pre_init
(
self
,
hparams
):
def
pre_init
(
self
,
hparams
):
raise
NotImplementedError
(
"T5 Distillation does not work yet"
)
raise
NotImplementedError
(
"T5 Distillation does not work yet"
)
...
@@ -364,15 +412,14 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
...
@@ -364,15 +412,14 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
def
create_module
(
args
):
def
create_module
(
args
):
t5
=
"t5"
in
args
.
model_name_or_path
t5
=
"t5"
in
args
.
model_name_or_path
if
args
.
no_teacher
:
if
args
.
no_teacher
:
assert
not
args
.
enc_only
module_cls
=
TranslationModule
if
"translation"
in
args
.
task
else
SummarizationModule
module_cls
=
SummarizationModule
elif
t5
:
# DISTILL T5 WITH TEACHER FOR SUMMARIZATION
elif
t5
:
assert
"translation"
not
in
args
.
task
,
"t5 translation distillation not supported"
module_cls
=
T5SummarizationDistiller
module_cls
=
T5SummarizationDistiller
elif
args
.
enc_only
:
else
:
# DISTILL WITH TEACHER
raise
ValueError
(
"Deleted that"
)
module_cls
=
BartTranslationDistiller
if
"translation"
in
args
.
task
else
BartSummarizationDistiller
else
:
module_cls
=
BartSummarizationDistiller
args
.
setup_cls
:
str
=
module_cls
.
__name__
args
.
setup_cls
:
str
=
module_cls
.
__name__
print
(
f
"using module
{
args
.
setup_cls
}
"
)
model
=
module_cls
(
args
)
model
=
module_cls
(
args
)
return
model
return
model
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
f94a52cd
...
@@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase):
...
@@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase):
# TODO: understand why this breaks
# TODO: understand why this breaks
self
.
assertEqual
(
nll_loss
,
model_computed_loss
)
self
.
assertEqual
(
nll_loss
,
model_computed_loss
)
def
test_distill_mbart
(
self
):
updates
=
dict
(
student_encoder_layers
=
2
,
student_decoder_layers
=
1
,
num_train_epochs
=
4
,
val_check_interval
=
0.25
,
alpha_hid
=
2.0
,
task
=
"translation"
,
model_name_or_path
=
"IGNORE_THIS_IT_DOESNT_GET_USED"
,
tokenizer_name
=
MBART_TINY
,
teacher
=
MBART_TINY
,
src_lang
=
"en_XX"
,
tgt_lang
=
"ro_RO"
,
)
model
=
self
.
_test_distiller_cli
(
updates
,
check_contents
=
False
)
ckpts
=
list
(
Path
(
model
.
output_dir
).
glob
(
"*.ckpt"
))
self
.
assertEqual
(
1
,
len
(
ckpts
))
transformer_ckpts
=
list
(
Path
(
model
.
output_dir
).
glob
(
"**/*.bin"
))
all_files
=
list
(
Path
(
model
.
output_dir
).
glob
(
"best_tfmr/*"
))
assert
len
(
all_files
)
>
2
self
.
assertEqual
(
len
(
transformer_ckpts
),
2
)
evaluate_checkpoint
(
ckpts
[
0
],
dest_dir
=
Path
(
tempfile
.
mkdtemp
()))
@
unittest
.
skip
(
"T5 distillation is broken at the moment"
)
@
unittest
.
skip
(
"T5 distillation is broken at the moment"
)
def
test_distill_t5
(
self
):
def
test_distill_t5
(
self
):
updates
=
dict
(
updates
=
dict
(
...
@@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase):
...
@@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase):
def
_test_distiller_cli
(
self
,
updates
,
check_contents
=
True
):
def
_test_distiller_cli
(
self
,
updates
,
check_contents
=
True
):
default_updates
=
dict
(
default_updates
=
dict
(
label_smoothing
_eps
=
0.0
,
label_smoothing
=
0.0
,
early_stopping_patience
=-
1
,
early_stopping_patience
=-
1
,
train_batch_size
=
1
,
train_batch_size
=
1
,
eval_batch_size
=
2
,
eval_batch_size
=
2
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment