"git@developer.sourcefind.cn:change/sglang.git" did not exist on "eb4b015f1219e9b27c9ab5766ff24056a2227a68"
Unverified Commit 4bd7be9a authored by Sam Shleifer, committed by GitHub

s2s distillation uses AutoModelForSeq2SeqLM (#6761)

parent 05e7150a
@@ -10,7 +10,7 @@ from torch import nn
 from torch.nn import functional as F
 from lightning_base import generic_train
-from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
 try:
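This import swap is the whole mechanism behind the title: AutoModelForSeq2SeqLM reads model_type from a checkpoint's config and instantiates the matching architecture, so one code path serves BART, mBART, T5, and so on. A minimal sketch of the dispatch, using public checkpoint names that are illustrative only and not part of this commit:

from transformers import AutoModelForSeq2SeqLM

# The Auto factory inspects config.model_type in each checkpoint's
# config.json and returns the matching concrete model class.
bart = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

print(type(bart).__name__)  # BartForConditionalGeneration
print(type(t5).__name__)    # T5ForConditionalGeneration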
@@ -74,22 +74,22 @@ class BartSummarizationDistiller(SummarizationModule):
     def pre_init(self, hparams):
         self.output_dir = Path(hparams.output_dir)
         self.output_dir.mkdir(exist_ok=True)
-        teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
+        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
         student_updates = {
             "decoder_layers": hparams.student_decoder_layers,
             "encoder_layers": hparams.student_encoder_layers,
         }
         if hparams.length_penalty != -1:
             student_updates["length_penalty"] = hparams.length_penalty
-        d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
         e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
         hparams.d_layer_to_copy = d_layers_to_copy
         hparams.e_layer_to_copy = e_layers_to_copy
         kw = teacher.config.to_diff_dict()
         kw.update(student_updates)
         # Copy weights
-        student_cfg = BartConfig(**kw)
-        student = BartForConditionalGeneration(student_cfg)
+        student_cfg = teacher.config_class(**kw)
+        student = type(teacher)(student_cfg)
         student, _ = init_student(student, teacher)
         save_dir = self.output_dir.joinpath("student")
         self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
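The last two replacements in this hunk carry the refactor: teacher.config_class and type(teacher) stand in for the hard-coded BartConfig and BartForConditionalGeneration, so the student is always built with the same architecture as whatever teacher was loaded. A standalone sketch of the pattern, with an illustrative checkpoint and layer counts rather than the distiller's defaults:

from transformers import AutoModelForSeq2SeqLM

teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").eval()

# Start from the teacher's non-default config values, then shrink the decoder.
kw = teacher.config.to_diff_dict()
kw.update({"decoder_layers": 3, "encoder_layers": 12})

# teacher.config_class is BartConfig for a BART teacher, MBartConfig for an
# mBART teacher, etc.; type(teacher) is the matching model class.
student_cfg = teacher.config_class(**kw)
student = type(teacher)(student_cfg)  # randomly initialized; weights are copied in later

print(student.config.decoder_layers)  # 3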
@@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        assert isinstance(self.tokenizer, MBartTokenizer)
         assert hparams.src_lang is not None
         assert hparams.tgt_lang is not None
         self.dataset_kwargs["src_lang"] = hparams.src_lang
@@ -186,6 +186,7 @@ class TestSummarizationDistiller(unittest.TestCase):
             tgt_lang="ro_RO",
         )
         model = self._test_distiller_cli(updates, check_contents=False)
+        assert model.model.config.model_type == "mbart"
         ckpts = list(Path(model.output_dir).glob("*.ckpt"))
         self.assertEqual(1, len(ckpts))
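The added assertion checks the point of the refactor end-to-end: distilling an mBART teacher must yield a student whose config still reports model_type == "mbart", instead of silently degrading to plain BART as the old hard-coded classes did. The underlying property, sketched with an illustrative checkpoint:

from transformers import AutoConfig

# model_type is stored in the checkpoint's config.json; it is what the
# Auto factories key on when choosing concrete classes.
config = AutoConfig.from_pretrained("facebook/mbart-large-cc25")
assert config.model_type == "mbart"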