Unverified Commit f94a52cd authored by Sam Shleifer, committed by GitHub

[s2s] add BartTranslationDistiller for distilling mBART (#6363)

parent d2370e1b
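This commit adds a translation-specific distillation path next to the existing summarization one: create_module now picks BartTranslationDistiller when the task is translation, validation is scored with BLEU via calculate_bleu_score, and the distillation argparse flags move into a shared add_distill_args helper. The mBART-specific detail handled by the new class is that decoding must start from the target-language code. A minimal sketch of that tokenizer lookup, not part of the commit itself (the checkpoint name is only an illustrative choice):

    from transformers import MBartTokenizer

    # Any mBART checkpoint works for illustration; "facebook/mbart-large-en-ro" is assumed here.
    tok = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro")

    # BartTranslationDistiller mirrors this lookup when the model config has no
    # decoder_start_token_id: mBART generation begins with the target-language code.
    decoder_start_token_id = tok.lang_code_to_id["ro_RO"]
    print(decoder_start_token_id)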
@@ -10,20 +10,40 @@ from torch import nn
 from torch.nn import functional as F

 from lightning_base import generic_train
-from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Config, T5ForConditionalGeneration
+from transformers import (
+    AdamW,
+    BartConfig,
+    BartForConditionalGeneration,
+    MBartTokenizer,
+    T5Config,
+    T5ForConditionalGeneration,
+)


 try:
-    from .finetune import SummarizationModule
-    from .finetune import main as ft_main
+    from .finetune import SummarizationModule, TranslationModule
     from .initialization_utils import init_student, copy_layers
-    from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
+    from .utils import (
+        use_task_specific_params,
+        pickle_load,
+        freeze_params,
+        assert_all_frozen,
+        any_requires_grad,
+        calculate_bleu_score,
+    )
+    from .finetune import main as ft_main
 except ImportError:
-    from finetune import SummarizationModule
+    from finetune import SummarizationModule, TranslationModule
     from finetune import main as ft_main
     from initialization_utils import init_student, copy_layers
-    from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
+    from utils import (
+        use_task_specific_params,
+        pickle_load,
+        freeze_params,
+        assert_all_frozen,
+        any_requires_grad,
+        calculate_bleu_score,
+    )


 class BartSummarizationDistiller(SummarizationModule):
@@ -159,17 +179,7 @@ class BartSummarizationDistiller(SummarizationModule):
     @staticmethod
     def add_model_specific_args(parser, root_dir):
         SummarizationModule.add_model_specific_args(parser, root_dir)
-        parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
-        parser.add_argument("--alpha_ce", default=0.8, type=float)
-        parser.add_argument("--alpha_mlm", default=0.2, type=float)
-        # parser.add_argument("--alpha_cos", default=0.0, type=float)
-        parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
-        parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
-        parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
-        parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
-        parser.add_argument("--no_teacher", action="store_true", default=False)
-        parser.add_argument("--length_penalty", type=float, default=-1)
+        add_distill_args(parser)
         return parser

     def _step(self, batch):
@@ -247,6 +257,44 @@ class BartSummarizationDistiller(SummarizationModule):
         return sum(hidden_losses)


+def add_distill_args(parser):
+    parser.add_argument("--teacher", default="facebook/bart-large-cnn", type=str)
+    parser.add_argument("--alpha_ce", default=0.8, type=float)
+    parser.add_argument("--alpha_mlm", default=0.2, type=float)
+    parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
+    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
+    parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
+    parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
+    parser.add_argument("--no_teacher", action="store_true", default=False)
+    parser.add_argument("--length_penalty", type=float, default=-1)
+
+
+class BartTranslationDistiller(BartSummarizationDistiller):
+    mode = "translation"
+    loss_names = ["loss"]
+    metric_names = ["bleu"]
+    val_metric = "bleu"
+
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        assert isinstance(self.tokenizer, MBartTokenizer)
+        assert hparams.src_lang is not None
+        assert hparams.tgt_lang is not None
+        self.dataset_kwargs["src_lang"] = hparams.src_lang
+        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
+        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
+            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+
+    def calc_generative_metrics(self, preds, target) -> dict:
+        return calculate_bleu_score(preds, target)
+
+    @staticmethod
+    def add_model_specific_args(parser, root_dir):
+        TranslationModule.add_model_specific_args(parser, root_dir)
+        add_distill_args(parser)
+        return parser
+
+
 class T5SummarizationDistiller(BartSummarizationDistiller):
     def pre_init(self, hparams):
         raise NotImplementedError("T5 Distillation does not work yet")
@@ -364,15 +412,14 @@ class T5SummarizationDistiller(BartSummarizationDistiller):
 def create_module(args):
     t5 = "t5" in args.model_name_or_path
     if args.no_teacher:
-        assert not args.enc_only
-        module_cls = SummarizationModule
-    elif t5:
+        module_cls = TranslationModule if "translation" in args.task else SummarizationModule
+    elif t5:  # DISTILL T5 WITH TEACHER FOR SUMMARIZATION
+        assert "translation" not in args.task, "t5 translation distillation not supported"
         module_cls = T5SummarizationDistiller
-    elif args.enc_only:
-        raise ValueError("Deleted that")
-    else:
-        module_cls = BartSummarizationDistiller
+    else:  # DISTILL WITH TEACHER
+        module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
     args.setup_cls: str = module_cls.__name__
+    print(f"using module {args.setup_cls}")
     model = module_cls(args)
     return model
@@ -166,6 +166,31 @@ class TestSummarizationDistiller(unittest.TestCase):
         # TODO: understand why this breaks
         self.assertEqual(nll_loss, model_computed_loss)

+    def test_distill_mbart(self):
+        updates = dict(
+            student_encoder_layers=2,
+            student_decoder_layers=1,
+            num_train_epochs=4,
+            val_check_interval=0.25,
+            alpha_hid=2.0,
+            task="translation",
+            model_name_or_path="IGNORE_THIS_IT_DOESNT_GET_USED",
+            tokenizer_name=MBART_TINY,
+            teacher=MBART_TINY,
+            src_lang="en_XX",
+            tgt_lang="ro_RO",
+        )
+        model = self._test_distiller_cli(updates, check_contents=False)
+
+        ckpts = list(Path(model.output_dir).glob("*.ckpt"))
+        self.assertEqual(1, len(ckpts))
+        transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
+        all_files = list(Path(model.output_dir).glob("best_tfmr/*"))
+        assert len(all_files) > 2
+        self.assertEqual(len(transformer_ckpts), 2)
+        evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
+
     @unittest.skip("T5 distillation is broken at the moment")
     def test_distill_t5(self):
         updates = dict(

@@ -180,7 +205,7 @@ class TestSummarizationDistiller(unittest.TestCase):
     def _test_distiller_cli(self, updates, check_contents=True):
         default_updates = dict(
-            label_smoothing_eps=0.0,
+            label_smoothing=0.0,
             early_stopping_patience=-1,
             train_batch_size=1,
             eval_batch_size=2,