#!/usr/bin/env python

import argparse
import gc
import os
import sys
from pathlib import Path
from typing import List

import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F

from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, pickle_load, use_task_specific_params


# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import generic_train  # noqa


class BartSummarizationDistiller(SummarizationModule):
    """Supports Bart, Pegasus and other models that inherit from Bart."""

    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]

    def __init__(self, hparams):
        assert Path(hparams.data_dir).exists()
        self.output_dir = Path(hparams.output_dir)
        self.output_dir.mkdir(exist_ok=True)

        save_dir = self.output_dir.joinpath("student")
        hparams.model_name_or_path = str(save_dir)  # Tell lightning we are training the student
        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
        use_task_specific_params(teacher, hparams.task)  # We copy good generation parameters to student by default
        student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
            teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
        )
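        # e_layer_ids / d_layer_ids record which teacher layers each student layer was
        # copied from; they are used below to pick layers for hidden-state supervision.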
        if hparams.length_penalty != -1:
            student.config.length_penalty = hparams.length_penalty
        super().__init__(hparams, model=student, config=student.config)
        model_type = student.config.model_type
        self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]

        if model_type == "t5":
            teacher_encoder_layers = len(teacher.get_encoder().block)
            teacher_decoder_layers = len(teacher.get_decoder().block)
        else:
            teacher_encoder_layers = teacher.config.encoder_layers
            teacher_decoder_layers = teacher.config.decoder_layers

        self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
        self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
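        # When the student keeps the full teacher stack, the corresponding hidden-state
        # loss is skipped (and the redundant teacher encoder is deleted below to save RAM).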

        self.teacher = teacher
        freeze_params(self.teacher)

        if not self.different_encoder:  # To save RAM, delete the teacher encoder; the student's copy is identical.
            try:
                del self.teacher.model.encoder
            except AttributeError:  # T5
                del self.teacher.encoder
        # Intermediate supervision: Decide which layers to supervise
        if hparams.supervise_forward:
            self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
            self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
        else:  # student layer should emulate hidden states of the teacher layer it was copied from
            self.e_matches = self.e_layer_ids
            self.d_matches = self.d_layer_ids
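
        # Distillation hyperparameters: batch-mean KL for the soft-target loss and a fixed
        # softmax temperature of 2.0, as in the DistilBERT distillation example.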
        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 2.0
        self.alpha_mlm = hparams.alpha_mlm
        self.alpha_ce = hparams.alpha_ce
        self.alpha_hid = hparams.alpha_hid
        gc.collect()
        torch.cuda.empty_cache()

    def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
        """Supervise MSE(teacher.encoder_outputs, student.encoder_outputs)."""
        # raise NotImplementedError()
        if mask is not None:
            # mask has False at padding_idx
            sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
            s_logits_slct = torch.masked_select(student_outputs, sel_mask)
            t_logits_slct = torch.masked_select(teacher_outputs, sel_mask)
        else:
            t_logits_slct = teacher_outputs
            s_logits_slct = student_outputs
        return F.mse_loss(s_logits_slct, t_logits_slct)

    def calc_ce_loss(self, mask, s_logits, t_logits):
        """Copy pasted from distillbert (transformers/examples/distillation/)"""

        # mask has False at padding_idx
        sel_mask = mask[:, :, None].expand_as(s_logits)
        vocab_size = s_logits.size(-1)
        s_logits_slct = torch.masked_select(s_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()
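        # Soft-target loss: KL divergence between temperature-softened student and teacher
        # distributions, scaled by temperature**2 to keep gradients comparable across temperatures.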
        loss_ce = (
            self.ce_loss_fct(
                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
                F.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        return loss_ce

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        SummarizationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser

    def _step(self, batch):
        # The teacher is frozen in __init__; only the student receives gradients here.
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
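        # Teacher forcing: build decoder inputs by shifting the labels one position to the
        # right; T5 and BART-like models expose different helpers for this shift.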
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
        lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
            output_attentions=False,
            use_cache=False,
        )

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
        else:
            lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
            )

        def zero_tensor():
            return torch.tensor(0.0).type_as(student_lm_loss)

        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        if self.different_encoder:  # compute encoder hidden state loss
            with torch.no_grad():
                teacher_enc_hid = self.teacher.get_encoder()(
                    input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
                ).hidden_states

            hid_loss_enc = self.calc_hidden_loss(
                src_mask,
                enc_hidden_state,
                teacher_enc_hid,
                self.e_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

        with torch.no_grad():
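            # The teacher decodes on top of the student's encoder outputs (enc_outputs),
            # so only the teacher decoder runs in this no-grad forward pass.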
            outputs = self.teacher(
                input_ids,
                attention_mask=src_mask,
                encoder_outputs=(enc_outputs,),
                decoder_input_ids=decoder_input_ids,
                lm_labels=labels,
                output_hidden_states=True,
                return_dict=True,
            )
            tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
        if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
            )
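
        # Final training objective: a weighted blend of the soft-target KL term (alpha_ce),
        # the hard-target LM loss (alpha_mlm), and the hidden-state losses (alpha_hid).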
        blended_loss = (
            self.alpha_ce * loss_ce
            + self.alpha_mlm * student_lm_loss
            + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
        )
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

    @staticmethod
    def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
        """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
        msg = "expected list or tuple for hidden_states, got tensor of shape: "
        assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
        assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
        mask = attention_mask.to(hidden_states[0])
        valid_count = mask.sum() * hidden_states[0].size(-1)
        student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
        teacher_states = torch.stack([hidden_states_T[j] for j in matches])
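        # Optionally layer-normalize both stacks so the MSE compares activation patterns
        # rather than absolute scales.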
        if normalize_hidden:
            student_states = F.layer_norm(student_states, student_states.shape[1:])
            teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
        mse = F.mse_loss(student_states, teacher_states, reduction="none")
        masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
        return masked_mse


def add_distill_args(parser):
    parser.add_argument("--teacher", type=str)
    parser.add_argument("--alpha_ce", default=0.8, type=float)
    parser.add_argument("--alpha_mlm", default=0.2, type=float)
    parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
    parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
    parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
    parser.add_argument("--no_teacher", action="store_true", default=False)
    parser.add_argument("--length_penalty", type=float, default=-1)
    parser.add_argument("--supervise_forward", action="store_true", default=False)
    parser.add_argument("--normalize_hidden", action="store_true", default=False)


class BartTranslationDistiller(BartSummarizationDistiller):
    """Supports Mbart, Marian, other models that inherit from Bart."""

    mode = "translation"
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        assert hparams.src_lang is not None
        assert hparams.tgt_lang is not None
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]

    def calc_generative_metrics(self, preds, target) -> dict:
        return calculate_bleu(preds, target)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        TranslationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser


def create_module(args):
    if args.no_teacher:
        module_cls = TranslationModule if "translation" in args.task else SummarizationModule
    else:  # DISTILL WITH TEACHER
        module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
    args.setup_cls: str = module_cls.__name__
    print(f"using module {args.setup_cls}")
    model = module_cls(args)
    return model


def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
    # TODO(SS): DELETE? Better to convert_pl_ckpt_to_hf and run_eval.py
    exp_dir = ckpt_path.parent
    if dest_dir is None:
        dest_dir = exp_dir
    clash = list(dest_dir.glob("test_generations*"))
    if clash:
        print(f"SKIPPING to avoid overwriting {clash}")
        return
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if "hparams" in ckpt:
        args = argparse.Namespace(**ckpt["hparams"])
    else:
        args = argparse.Namespace(**pickle_load(exp_dir / "hparams.pkl"))
    args.resume_from_checkpoint = str(ckpt_path)
    args.do_train = False
    args.output_dir = str(dest_dir)
    args.n_gpu = 1
    args.eval_batch_size = 16
    Path(args.output_dir).mkdir(exist_ok=True)
    model = create_module(args)
    trainer: pl.Trainer = generic_train(model, args, early_stopping_callback=False)
    trainer.test(model)


def distill_main(args):
    Path(args.output_dir).mkdir(exist_ok=True)
    if len(os.listdir(args.output_dir)) > 3 and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))

    model = create_module(args)
    return ft_main(args, model=model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()

    distill_main(args)
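

# Example invocation (illustrative only): the model name, data path, and hyperparameter
# values below are placeholders, and flags such as --data_dir, --do_train, and --output_dir
# come from the argument sets added by finetune.py / lightning_base.py.
#
#   python distillation.py \
#       --teacher facebook/bart-large-cnn \
#       --data_dir $CNN_DIR \
#       --student_encoder_layers 12 --student_decoder_layers 6 \
#       --alpha_hid 3.0 --supervise_forward --normalize_hidden \
#       --do_train --output_dir distilbart_cnn_12_6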