"examples/summarization/bart/__init__.py" did not exist on "0bab55d5d52e4d538888980d05d73acc6da6274a"
distillation.py 13.3 KB
Newer Older
1
2
#!/usr/bin/env python
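"""Distill a seq2seq teacher (BART/Pegasus for summarization, mBART/Marian for translation) into a
smaller student created by copying a subset of the teacher's layers, then train the student with a
blend of hard-label cross entropy, soft-target KL, and optional hidden-state matching losses.

Illustrative invocation (flags are defined below and in finetune.py; model names, paths, and values
are placeholders to adapt to your setup):

    python distillation.py \
        --teacher facebook/bart-large-xsum --data_dir xsum \
        --student_encoder_layers 12 --student_decoder_layers 6 \
        --alpha_hid 3.0 \
        --output_dir distilbart_xsum_12_6 --do_train
"""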

import argparse
import gc
import os
import sys
from pathlib import Path
from typing import List

import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F

from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, pickle_load, use_task_specific_params


# Make the parent directory importable so that lightning_base can be found.
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import generic_train  # noqa


class BartSummarizationDistiller(SummarizationModule):
    """Supports Bart, Pegasus and other models that inherit from Bart."""

    loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]

    def __init__(self, hparams):
        assert Path(hparams.data_dir).exists()
        self.output_dir = Path(hparams.output_dir)
        self.output_dir.mkdir(exist_ok=True)

        save_dir = self.output_dir.joinpath("student")

        hparams.model_name_or_path = str(save_dir)  # Tell lightning we are training the student
        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
        use_task_specific_params(teacher, hparams.task)  # We copy good generation parameters to student by default
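        # Build the student by copying alternating layers from the teacher (see make_student.py);
        # e_layer_ids / d_layer_ids record which teacher layers the student layers were copied from.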
        student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
            teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
        )
        if hparams.length_penalty != -1:
            student.config.length_penalty = hparams.length_penalty
        super().__init__(hparams, model=student, config=student.config)
        self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]
        self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
        self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
        self.teacher = teacher
        freeze_params(self.teacher)

        # When the student copies every teacher encoder layer, the teacher encoder is never queried,
        # so delete it to save RAM.
        if not self.different_encoder:
            try:
                del self.teacher.model.encoder
            except AttributeError:  # T5
                del self.teacher.encoder
        # Intermediate supervision: Decide which layers to supervise
        if hparams.supervise_forward:
            self.d_matches = get_layers_to_supervise(
                n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
            )
        else:
            self.d_matches = self.d_layer_ids
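        # Distillation losses: Hinton-style soft targets (KL divergence between temperature-softened
        # student/teacher decoder logits) blended with the hard-label LM loss and, optionally,
        # encoder-output MSE and hidden-state matching terms, weighted by the alpha_* hyperparameters.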
        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 2.0
        self.alpha_mlm = hparams.alpha_mlm
        self.alpha_ce = hparams.alpha_ce
        self.alpha_hid = hparams.alpha_hid
        self.alpha_encoder_loss = hparams.alpha_encoder_loss
        gc.collect()
        torch.cuda.empty_cache()

    def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
        """Supervise MSE(teacher.encoder_outputs, student.encoder_outputs)."""
        if mask is not None:
            # mask has False at padding_idx
            sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
            s_logits_slct = torch.masked_select(student_outputs, sel_mask)
            t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
        else:
            t_logits_slct = teacher_outputs
            s_logits_slct = student_outputs
        return F.mse_loss(s_logits_slct, t_logits_slct)

    def calc_ce_loss(self, mask, s_logits, t_logits):
        """Copy pasted from distillbert (transformers/examples/distillation/)"""

        # mask has False at padding_idx
        sel_mask = mask[:, :, None].expand_as(s_logits)
        vocab_size = s_logits.size(-1)
        s_logits_slct = torch.masked_select(s_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()
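        # Temperature-softened softmax on both distributions; the KL term is scaled by T**2 so that
        # gradient magnitudes stay comparable across temperatures (standard knowledge-distillation trick).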
        loss_ce = (
            self.ce_loss_fct(
                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
                F.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        return loss_ce

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        SummarizationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser

    def _step(self, batch):
        # The teacher is frozen in __init__; only the student receives gradient updates here.
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
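        # Student forward pass; hidden states are requested so they can later be matched against
        # the teacher's hidden states for intermediate supervision.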
        lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            output_attentions=False,
            use_cache=False,
        )  # TODO(@sshleifer): return_dict=True cleanup

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
            )

        def zero_tensor():
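            # Zero scalar with the same dtype/device as the LM loss, used for loss terms that may be skipped.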
            return torch.tensor(0.0).type_as(student_lm_loss)

        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
        if self.different_encoder:
            with torch.no_grad():
                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
                    input_ids, attention_mask=src_mask, output_hidden_states=True
                )
            # DEPRECATE THIS
            if self.hparams.alpha_encoder_loss > 0:
                loss_encoder = self.calc_mse_loss(teacher_enc_outputs, enc_outputs, src_mask)

            hid_loss_enc = self.calc_hidden_loss(
                src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids,
                normalize_hidden=self.hparams.normalize_hidden,
            )

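        # Feed the student's encoder outputs to the frozen teacher decoder so that both decoders
        # condition on the same encoder memory.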
        teacher_enc_outputs = (enc_outputs,)
        assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)

        with torch.no_grad():
            tloss, tlogits, tdec_hidden, _ = self.teacher(
                input_ids,
                attention_mask=src_mask,
                encoder_outputs=teacher_enc_outputs,
                decoder_input_ids=decoder_input_ids,
                lm_labels=labels,
                output_hidden_states=True,
            )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
        if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
            )

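        # Final objective: weighted sum of soft-target KL, student LM loss, encoder-output MSE
        # (deprecated above), and hidden-state matching losses.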
        blended_loss = (
            self.alpha_ce * loss_ce
            + self.alpha_mlm * student_lm_loss
            + self.hparams.alpha_encoder_loss * loss_encoder
            + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
        )
        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec

    @staticmethod
    def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
        """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
        msg = "expected list or tuple for hidden_states, got tensor of shape: "
        assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
        assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
        mask = attention_mask.to(hidden_states[0])
        valid_count = mask.sum() * hidden_states[0].size(-1)
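        # Stack the matched layers, optionally layer-normalize them, then compute an MSE that is
        # masked to non-pad positions and normalized by valid_count (non-pad tokens * hidden size).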
        student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
        teacher_states = torch.stack([hidden_states_T[j] for j in matches])
        if normalize_hidden:
            student_states = F.layer_norm(student_states, student_states.shape[1:])
            teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
        mse = F.mse_loss(student_states, teacher_states, reduction="none")
        masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
        return masked_mse


def add_distill_args(parser):
    parser.add_argument("--teacher", type=str)
    parser.add_argument("--alpha_ce", default=0.8, type=float)
    parser.add_argument("--alpha_mlm", default=0.2, type=float)
    parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
    parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
    parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
    parser.add_argument("--no_teacher", action="store_true", default=False)
    parser.add_argument("--length_penalty", type=float, default=-1)
    parser.add_argument("--supervise_forward", action="store_true", default=False)
    parser.add_argument("--normalize_hidden", action="store_true", default=False)


class BartTranslationDistiller(BartSummarizationDistiller):
    """Supports Mbart, Marian, other models that inherit from Bart."""

    mode = "translation"
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        assert hparams.src_lang is not None
        assert hparams.tgt_lang is not None
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        TranslationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser


def create_module(args):
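    # With --no_teacher, fall back to plain fine-tuning (the finetune.py modules); otherwise distill.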
    if args.no_teacher:
        module_cls = TranslationModule if "translation" in args.task else SummarizationModule
    else:  # DISTILL WITH TEACHER
        module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
    args.setup_cls: str = module_cls.__name__
    print(f"using module {args.setup_cls}")
    model = module_cls(args)
    return model


def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
    # TODO(SS): DELETE? Better to convert_pl_ckpt_to_hf and run_eval.py
    exp_dir = ckpt_path.parent
    if dest_dir is None:
        dest_dir = exp_dir
    clash = list(dest_dir.glob("test_generations*"))
    if clash:
        print(f"SKIPPING to avoid overwriting {clash}")
        return
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if "hparams" in ckpt:
        args = argparse.Namespace(**ckpt["hparams"])
    else:
        args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
    args.resume_from_checkpoint = str(ckpt_path)
    args.do_train = False
    args.output_dir = str(dest_dir)
    args.n_gpu = 1
    args.eval_batch_size = 16
    Path(args.output_dir).mkdir(exist_ok=True)
    model = create_module(args)
    trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
    trainer.test(model)


def distill_main(args):
    Path(args.output_dir).mkdir(exist_ok=True)
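    # Tolerate a handful of pre-existing entries (presumably leftovers such as a saved student dir
    # from an earlier attempt) before refusing to train into a non-empty output_dir.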
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    model = create_module(args)
    return ft_main(args, model=model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    distill_main(args)