Commit f94a52cd (unverified)
Authored Aug 12, 2020 by Sam Shleifer; committed by GitHub on Aug 12, 2020

[s2s] add BartTranslationDistiller for distilling mBART (#6363)

Parent: d2370e1b
Changes: 2 changed files with 98 additions and 26 deletions (+98, -26)
examples/seq2seq/distillation.py           +72 -25
examples/seq2seq/test_seq2seq_examples.py  +26 -1
examples/seq2seq/distillation.py (view file @ f94a52cd)
@@ -10,20 +10,40 @@ from torch import nn
 from torch.nn import functional as F

 from lightning_base import generic_train
-from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
+from transformers import (
+    AdamW,
+    BartConfig,
+    BartForConditionalGeneration,
+    MBartTokenizer,
+    T5Config,
+    T5ForConditionalGeneration,
+)


 try:
-    from .finetune import SummarizationModule
+    from .finetune import SummarizationModule, TranslationModule
     from .finetune import main as ft_main
     from .initialization_utils import init_student, copy_layers
-    from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
+    from .utils import (
+        use_task_specific_params,
+        pickle_load,
+        freeze_params,
+        assert_all_frozen,
+        any_requires_grad,
+        calculate_bleu_score,
+    )
 except ImportError:
-    from finetune import SummarizationModule
+    from finetune import SummarizationModule, TranslationModule
     from finetune import main as ft_main
     from initialization_utils import init_student, copy_layers
-    from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
+    from utils import (
+        use_task_specific_params,
+        pickle_load,
+        freeze_params,
+        assert_all_frozen,
+        any_requires_grad,
+        calculate_bleu_score,
+    )


 class BartSummarizationDistiller(SummarizationModule):
@@ -159,17 +179,7 @@ class BartSummarizationDistiller(SummarizationModule):
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         SummarizationModule.add_model_specific_args(parser, root_dir)
-        parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
-        parser.add_argument("--alpha_ce", default=0.8, type=float)
-        parser.add_argument("--alpha_mlm", default=0.2, type=float)
-        # parser.add_argument("--alpha_cos", default=0.0, type=float)
-        parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
-        parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
-        parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
-        parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
-        parser.add_argument("--no_teacher", action="store_true", default=False)
-        parser.add_argument("--length_penalty", type=float, default=-1)
+        add_distill_args(parser)
         return parser

     def _step(self, batch):
@@ -247,6 +257,44 @@ class BartSummarizationDistiller(SummarizationModule):
         return sum(hidden_losses)


+def add_distill_args(parser):
+    parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
+    parser.add_argument("--alpha_ce", default=0.8, type=float)
+    parser.add_argument("--alpha_mlm", default=0.2, type=float)
+    parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
+    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
+    parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
+    parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
+    parser.add_argument("--no_teacher", action="store_true", default=False)
+    parser.add_argument("--length_penalty", type=float, default=-1)
+
+
+class BartTranslationDistiller(BartSummarizationDistiller):
+    mode = "translation"
+    loss_names = ["loss"]
+    metric_names = ["bleu"]
+    val_metric = "bleu"
+
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        assert isinstance(self.tokenizer, MBartTokenizer)
+        assert hparams.src_lang is not None
+        assert hparams.tgt_lang is not None
+        self.dataset_kwargs["src_lang"] = hparams.src_lang
+        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
+        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
+            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+
+    def calc_generative_metrics(self, preds, target) -> dict:
+        return calculate_bleu_score(preds, target)
+
+    @staticmethod
+    def add_model_specific_args(parser, root_dir):
+        TranslationModule.add_model_specific_args(parser, root_dir)
+        add_distill_args(parser)
+        return parser
+
+
 class T5SummarizationDistiller(BartSummarizationDistiller):
     def pre_init(self, hparams):
         raise NotImplementedError("T5 Distillation does not work yet")
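One detail in the new __init__ is worth spelling out: mBART generation starts the decoder with the target-language code rather than an ordinary BOS token, so when the model config does not already pin decoder_start_token_id the distiller looks it up from the tokenizer. A minimal sketch of that lookup, assuming a transformers install from this era with network access; the checkpoint name facebook/mbart-large-cc25 is an illustrative choice, not something this commit fixes:

# Sketch only: the lang_code_to_id lookup used in BartTranslationDistiller.__init__.
# Assumes the tokenizer can be downloaded; the checkpoint name is illustrative.
from transformers import MBartTokenizer

tok = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
start_id = tok.lang_code_to_id["ro_RO"]  # used as decoder_start_token_id for en_XX -> ro_RO
print(start_id)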
@@ -364,15 +412,14 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
 def create_module(args):
     t5 = "t5" in args.model_name_or_path
     if args.no_teacher:
-        assert not args.enc_only
-        module_cls = SummarizationModule
-    elif t5:  # DISTILL T5 WITH TEACHER FOR SUMMARIZATION
+        module_cls = TranslationModule if "translation" in args.task else SummarizationModule
+    elif t5:
+        assert "translation" not in args.task, "t5 translation distillation not supported"
         module_cls = T5SummarizationDistiller
-    elif args.enc_only:
-        raise ValueError("Deleted that")
-    else:  # DISTILL WITH TEACHER
-        module_cls = BartSummarizationDistiller
+    else:
+        module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
     args.setup_cls: str = module_cls.__name__
+    print(f"using module {args.setup_cls}")
     model = module_cls(args)
     return model
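For orientation, the reworked create_module routing can be exercised on its own. The sketch below mirrors just that selection logic as a standalone function, returning class names as strings so nothing from the example scripts needs to be imported; it is an illustration of the commit's routing, not code from the repository, and the checkpoint name is a placeholder:

from types import SimpleNamespace

def pick_module(args) -> str:
    # Mirrors create_module after this commit: translation tasks with a teacher
    # now route to BartTranslationDistiller instead of raising.
    t5 = "t5" in args.model_name_or_path
    if args.no_teacher:
        return "TranslationModule" if "translation" in args.task else "SummarizationModule"
    elif t5:
        assert "translation" not in args.task, "t5 translation distillation not supported"
        return "T5SummarizationDistiller"
    else:
        return "BartTranslationDistiller" if "translation" in args.task else "BartSummarizationDistiller"

# Any non-T5 checkpoint name with task="translation" and a teacher picks the new distiller:
args = SimpleNamespace(model_name_or_path="some-mbart-checkpoint", task="translation", no_teacher=False)
print(pick_module(args))  # BartTranslationDistiller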
examples/seq2seq/test_seq2seq_examples.py (view file @ f94a52cd)
@@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase):
         # TODO: understand why this breaks
         self.assertEqual(nll_loss, model_computed_loss)

+    def test_distill_mbart(self):
+        updates = dict(
+            student_encoder_layers=2,
+            student_decoder_layers=1,
+            num_train_epochs=4,
+            val_check_interval=0.25,
+            alpha_hid=2.0,
+            task="translation",
+            model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
+            tokenizer_name=MBART_TINY,
+            teacher=MBART_TINY,
+            src_lang="en_XX",
+            tgt_lang="ro_RO",
+        )
+        model = self._test_distiller_cli(updates, check_contents=False)
+
+        ckpts = list(Path(model.output_dir).glob("*.ckpt"))
+        self.assertEqual(1, len(ckpts))
+        transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
+        all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
+        assert len(all_files) > 2
+        self.assertEqual(len(transformer_ckpts), 2)
+        evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
+
     @unittest.skip("T5 distillation is broken at the moment")
     def test_distill_t5(self):
         updates = dict(
@@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase):
     def _test_distiller_cli(self, updates, check_contents=True):
         default_updates = dict(
-            label_smoothing_eps=0.0,
+            label_smoothing=0.0,
             early_stopping_patience=-1,
             train_batch_size=1,
             eval_batch_size=2,
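The new test drives a full translation-distillation loop, so validation reports the bleu metric that BartTranslationDistiller defines via calc_generative_metrics. The calculate_bleu_score helper it delegates to lives in examples/seq2seq/utils.py and is not shown in this diff; a helper of that shape typically wraps sacrebleu, roughly as in this hedged sketch:

# Hedged sketch of a calculate_bleu_score-style helper; the real implementation is in
# examples/seq2seq/utils.py and is not part of this commit's diff.
from sacrebleu import corpus_bleu

def bleu(output_lns, refs_lns) -> dict:
    # corpus_bleu takes the system outputs and a list of reference streams.
    return {"bleu": corpus_bleu(output_lns, [refs_lns]).score}

print(bleu(["the cat sat on the mat"], ["a cat sat on the mat"]))  # prints a dict with a "bleu" score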