distillation.py 14.2 KB
Newer Older
1
2
#!/usr/bin/env python

3
4
5
import argparse
import gc
import os
6
import sys
7
8
9
10
11
12
13
14
from pathlib import Path
from typing import List

import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F

15
16
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
17
18
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
Sylvain Gugger's avatar
Sylvain Gugger committed
19
from transformers.models.bart.modeling_bart import shift_tokens_right
20
from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params
21
22


23
24
25
26
27
# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import generic_train  # noqa


28
29
class SummarizationDistiller(SummarizationModule):
    """Supports T5, Bart, Pegasus and other models that inherit from Bart.

    Trains a student seq2seq model against a frozen teacher. The training loss (see ``_step``)
    blends three terms:

    * a temperature-scaled KL divergence between student and teacher logits (``alpha_ce``),
    * the student's own LM loss on the labels (``alpha_mlm``),
    * an optional MSE between matched student/teacher hidden states (``alpha_hid``).
    """

    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]

    def __init__(self, hparams):
        """Load the teacher, load or build the student, and decide which layers to supervise.

        Args:
            hparams: parsed command-line namespace (see ``add_distill_args`` and
                ``SummarizationModule.add_model_specific_args``). ``hparams.model_name_or_path``
                and ``hparams.tokenizer_name`` are overwritten here.
        """
        assert Path(hparams.data_dir).exists()
        self.output_dir = Path(hparams.output_dir)
        self.output_dir.mkdir(exist_ok=True)

        save_dir = self.output_dir.joinpath("student")

        hparams.model_name_or_path = str(save_dir)  # Tell lightning we are training the student
        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
        use_task_specific_params(teacher, hparams.task)  # We copy good generation parameters to student by default

        if hparams.student is not None:
            student = AutoModelForSeq2SeqLM.from_pretrained(hparams.student)
            use_task_specific_params(student, hparams.task)
            e_layer_ids, d_layer_ids = None, None
        else:
            # Student initialised from a subset of the teacher's layers; the returned ids record which ones.
            student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
                teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
            )

        if hparams.length_penalty != -1:  # -1 means "keep the model's configured value"
            student.config.length_penalty = hparams.length_penalty
        hparams.tokenizer_name = hparams.teacher  # Use teacher's tokenizer
        super().__init__(hparams, model=student, config=student.config)
        assert (
            student.config.model_type == teacher.config.model_type
        ), f"teacher, student model types should be the same, got {student.config.model_type} != {teacher.config.model_type}"

        # T5 configs do not expose encoder_layers/decoder_layers, so count the blocks directly.
        if student.config.model_type == "t5":
            student_encoder_layers = len(student.get_encoder().block)
            student_decoder_layers = len(student.get_decoder().block)
            teacher_encoder_layers = len(teacher.get_encoder().block)
            teacher_decoder_layers = len(teacher.get_decoder().block)
        else:
            student_encoder_layers = student.config.encoder_layers
            student_decoder_layers = student.config.decoder_layers
            teacher_encoder_layers = teacher.config.encoder_layers
            teacher_decoder_layers = teacher.config.decoder_layers

        self.different_base_models = not (hparams.student is None or hparams.teacher == hparams.student)
        # Hidden-state supervision only makes sense when student and teacher share a base model.
        self.do_calc_hidden_loss = (not self.different_base_models) and hparams.alpha_hid > 0
        # self.different_encoder determines whether we need to run the teacher encoder
        self.different_encoder = self.different_base_models or (student_encoder_layers != teacher_encoder_layers)

        self.teacher = teacher
        freeze_params(self.teacher)

        if not self.different_encoder:
            # To save RAM, delete the teacher encoder. The student encoder then matches the teacher's
            # (same base model and layer count), and _step feeds the student's encoder output to the
            # teacher decoder instead.
            # NOTE(review): the student encoder is NOT frozen in this class; presumably that is handled
            # elsewhere (e.g. a SummarizationModule option) — confirm.
            try:
                del self.teacher.model.encoder
            except AttributeError:  # T5
                del self.teacher.encoder

        if e_layer_ids is None:
            e_layer_ids = list(range(student_encoder_layers))
        if d_layer_ids is None:
            d_layer_ids = list(range(student_decoder_layers))

        self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]

        if self.do_calc_hidden_loss:  # Intermediate supervision: Decide which layers to supervise
            if hparams.supervise_forward:
                self.e_matches = get_layers_to_supervise(
                    n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers
                )
                self.d_matches = get_layers_to_supervise(
                    n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers
                )
            else:  # student layer should emulate hidden states of the teacher layer it was copied from
                self.e_matches = self.e_layer_ids
                self.d_matches = self.d_layer_ids
        else:
            self.e_matches = None
            self.d_matches = None

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 2.0
        self.alpha_mlm = hparams.alpha_mlm
        self.alpha_ce = hparams.alpha_ce
        self.alpha_hid = hparams.alpha_hid
        gc.collect()
        torch.cuda.empty_cache()

    def train(self, mode: bool = True):
        """Set train/eval mode on the student while always keeping the frozen teacher in eval mode.

        ``nn.Module.train`` recurses into every child module, and ``self.teacher`` is a child.
        Without this override, the trainer's ``model.train()`` would undo the ``.eval()`` applied
        to the teacher in ``__init__`` and turn its dropout back on.

        Returns:
            self, like ``nn.Module.train``.
        """
        super().train(mode)
        teacher = getattr(self, "teacher", None)  # may not be set yet if called during construction
        if teacher is not None:
            teacher.eval()
        return self

    def calc_ce_loss(self, mask, s_logits, t_logits):
        """Copy pasted from distillbert (transformers/examples/distillation/).

        Temperature-scaled KL divergence between student and teacher distributions at non-padding
        positions.

        Args:
            mask: boolean tensor, False at padding positions. It is broadcast over the last (vocab)
                dimension of the logits.
            s_logits: student logits with the vocabulary on the last dimension.
            t_logits: teacher logits with the same shape as ``s_logits``.

        Returns:
            Scalar loss, multiplied by ``temperature ** 2`` to keep gradient magnitudes comparable
            across temperatures.
        """
        # mask has False at padding_idx
        sel_mask = mask[:, :, None].expand_as(s_logits)
        vocab_size = s_logits.size(-1)
        s_logits_slct = torch.masked_select(s_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(t_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()
        loss_ce = (
            self.ce_loss_fct(
                F.log_softmax(s_logits_slct / self.temperature, dim=-1),
                F.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        return loss_ce

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Add the summarization fine-tuning args plus the distillation args to ``parser`` and return it."""
        SummarizationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser

    def _step(self, batch: dict) -> tuple:
        """Compute the loss for a batch.

        Args:
            batch: dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.

        Returns:
            ``(blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec)``, in the
            order of ``loss_names``.
        """
        pad_token_id = self.tokenizer.pad_token_id
        input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
        if isinstance(self.model, T5ForConditionalGeneration):
            decoder_input_ids = self.model._shift_right(labels)
        else:
            decoder_input_ids = shift_tokens_right(labels, pad_token_id)

        # noinspection PyCallingNonCallable
        student_outputs = self(
            input_ids,
            attention_mask=src_mask,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            output_attentions=False,
            use_cache=False,
        )
        lm_logits = student_outputs.logits

        # Same cross entropy vs. label smoothing logic as finetune.py
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        if self.hparams.label_smoothing == 0:
            # Same behavior as modeling_bart.py, besides ignoring pad_token_id
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
            student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
        else:
            lprobs = F.log_softmax(lm_logits, dim=-1)
            student_lm_loss, _ = label_smoothed_nll_loss(
                lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
            )

        def zero_tensor():
            # Scalar zero with the same dtype/device as the LM loss, so the loss terms can be summed.
            return torch.tensor(0.0).type_as(student_lm_loss)

        teacher_enc_outputs = student_outputs.encoder_last_hidden_state  # use this unless self.different_base_models
        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
        # The teacher encoder's output is consumed only when the base models differ (its last hidden
        # state feeds the teacher decoder) or for the encoder hidden-state loss. Otherwise, skip it.
        teacher_encoder_needed = self.different_base_models or self.do_calc_hidden_loss
        if self.different_encoder and teacher_encoder_needed:
            all_teacher_encoder_outputs = self.teacher.get_encoder()(
                input_ids,
                attention_mask=src_mask,
                output_hidden_states=self.do_calc_hidden_loss,
            )
            if self.different_base_models:
                teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
            elif self.do_calc_hidden_loss:
                hid_loss_enc = self.calc_hidden_loss(
                    src_mask,
                    student_outputs.encoder_hidden_states,
                    all_teacher_encoder_outputs.hidden_states,
                    self.e_matches,
                    normalize_hidden=self.hparams.normalize_hidden,
                )

        # NOTE(review): when teacher_enc_outputs comes from the student, gradients of loss_ce can flow
        # through the (frozen) teacher decoder into the student encoder — confirm this is intended.
        teacher_outputs = self.teacher(
            input_ids,
            attention_mask=src_mask,
            encoder_outputs=(teacher_enc_outputs,),
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=self.do_calc_hidden_loss,
            use_cache=False,  # since we are not passing labels, never let this default to True
        )
        dec_mask = decoder_input_ids.ne(pad_token_id)
        loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
        if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
            hid_loss_dec = self.calc_hidden_loss(
                dec_mask,
                student_outputs.decoder_hidden_states,
                teacher_outputs.decoder_hidden_states,
                self.d_matches,
                normalize_hidden=self.hparams.normalize_hidden,
            )

        blended_loss = (
            self.alpha_ce * loss_ce
            + self.alpha_mlm * student_lm_loss
            + self.alpha_hid * (hid_loss_enc + hid_loss_dec)
        )
        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec

    @staticmethod
    def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
        """MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT.

        Args:
            attention_mask: mask over positions; cast to the hidden states' dtype/device and used to
                zero out padding.
            hidden_states: sequence of student hidden-state tensors. The first ``len(matches)`` are used.
            hidden_states_T: sequence of teacher hidden-state tensors, indexed by ``matches``.
            matches: teacher index for each of the first ``len(matches)`` student hidden states.
            normalize_hidden: if True, apply ``F.layer_norm`` to the stacked states before the MSE.

        Returns:
            Scalar: masked squared error summed over matched layers, divided by
            (number of unmasked positions * hidden size).
        """
        msg = "expected list or tuple for hidden_states, got tensor of shape: "
        assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
        assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
        mask = attention_mask.to(hidden_states[0])
        valid_count = mask.sum() * hidden_states[0].size(-1)
        student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
        teacher_states = torch.stack([hidden_states_T[j] for j in matches])
        assert student_states.shape == teacher_states.shape, f"{student_states.shape} != {teacher_states.shape}"
        if normalize_hidden:
            student_states = F.layer_norm(student_states, student_states.shape[1:])
            teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
        mse = F.mse_loss(student_states, teacher_states, reduction="none")
        # Broadcast the mask over the stacked-layer (dim 0) and hidden (last) dimensions.
        masked_mse = (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
        return masked_mse


def add_distill_args(parser):
    """Register the distillation-specific command-line options on ``parser`` (modified in place).

    NOTE: if --student is specified and the teacher and student base models are different,
    the two models must still share a tokenizer. SummarizationDistiller always loads the
    teacher's tokenizer. So, for example, you can distill from t5_large to t5_small but not
    from bart to t5. This is because if the tokenizers are different, the output spaces of
    the two models differ and their logits are not comparable.
    """
    parser.add_argument(
        "--teacher", type=str, help="Name or path of the teacher model (its tokenizer is also used for the student)."
    )
    parser.add_argument(
        "--alpha_ce", default=0.8, type=float, help="Weight of the soft cross-entropy (KL) loss against the teacher."
    )
    parser.add_argument(
        "--alpha_mlm", default=0.2, type=float, help="Weight of the student's own LM loss on the labels."
    )
    parser.add_argument(
        "--alpha_hid",
        default=0.0,
        type=float,
        required=False,
        help="Weight of the intermediate hidden-state MSE loss (0 disables hidden-state supervision).",
    )
    parser.add_argument(
        "--student",
        type=str,
        required=False,
        help="Optional name or path of a pretrained student. If omitted, the student is built by copying teacher layers.",
    )
    parser.add_argument(
        "--student_decoder_layers",
        default=12,
        type=int,
        required=False,
        help="Number of decoder layers to copy from the teacher when --student is not given.",
    )
    parser.add_argument(
        "--student_encoder_layers",
        default=12,
        type=int,
        required=False,
        help="Number of encoder layers to copy from the teacher when --student is not given.",
    )
    parser.add_argument(
        "--no_teacher", action="store_true", default=False, help="Plain fine-tuning without a teacher."
    )
    parser.add_argument(
        "--length_penalty",
        type=float,
        default=-1,
        help="Override the student's generation length_penalty; -1 keeps the configured value.",
    )
    parser.add_argument(
        "--supervise_forward",
        action="store_true",
        default=False,
        help="Choose supervised teacher layers with get_layers_to_supervise instead of the layers each student layer was copied from.",
    )
    parser.add_argument(
        "--normalize_hidden",
        action="store_true",
        default=False,
        help="Apply layer_norm to hidden states before computing the hidden-state loss.",
    )


class TranslationDistiller(SummarizationDistiller):
    """Supports T5, mBART, Marian, other models that inherit from Bart.

    Same distillation setup as SummarizationDistiller, but validated with BLEU and with source and
    target languages forwarded to the dataset.
    """

    mode = "translation"
    metric_names = ["bleu"]
    default_val_metric = "bleu"

    def __init__(self, hparams, **kwargs):
        """Validate the language options, then build the distiller.

        Args:
            hparams: parsed command-line namespace; must provide non-None ``src_lang`` and ``tgt_lang``.

        Raises:
            ValueError: if ``src_lang`` or ``tgt_lang`` is missing. Checked before the parent
                constructor so that we fail before the teacher and student models are loaded.
        """
        for lang_arg in ("src_lang", "tgt_lang"):
            if getattr(hparams, lang_arg, None) is None:
                raise ValueError(f"--{lang_arg} is required for translation distillation")
        super().__init__(hparams, **kwargs)
        self.dataset_kwargs["src_lang"] = hparams.src_lang
        self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
        # mBART starts decoding with the target-language code when the config does not set a start token.
        if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
            self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]

    def calc_generative_metrics(self, preds, target) -> dict:
        """Score generated translations against references with BLEU."""
        return calculate_bleu(preds, target)

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Add the translation fine-tuning args plus the distillation args to ``parser`` and return it."""
        TranslationModule.add_model_specific_args(parser, root_dir)
        add_distill_args(parser)
        return parser


def create_module(args):
    """Instantiate the Lightning module selected by ``args.no_teacher`` and ``args.task``.

    Without a teacher, a plain fine-tuning module is used; otherwise, a distiller is used. Either
    way, the translation variant is picked when ``args.task`` contains "translation". The chosen
    class name is stored on ``args.setup_cls`` and printed.
    """
    if args.no_teacher:
        summarization_cls, translation_cls = SummarizationModule, TranslationModule
    else:  # DISTILL WITH TEACHER
        summarization_cls, translation_cls = SummarizationDistiller, TranslationDistiller
    if "translation" in args.task:
        chosen_cls = translation_cls
    else:
        chosen_cls = summarization_cls
    args.setup_cls = chosen_cls.__name__
    print(f"using module {args.setup_cls}")
    return chosen_cls(args)


def distill_main(args):
    """Create ``args.output_dir``, check its contents, build the module and run ``ft_main`` on it.

    Returns whatever ``ft_main`` returns.
    """
    out_dir = Path(args.output_dir)
    out_dir.mkdir(exist_ok=True)
    check_output_dir(args, expected_items=3)
    lightning_module = create_module(args)
    return ft_main(args, model=lightning_module)


if __name__ == "__main__":
    # Trainer flags first, then the summarization and distillation options.
    cli_parser = pl.Trainer.add_argparse_args(argparse.ArgumentParser())
    cli_parser = SummarizationDistiller.add_model_specific_args(cli_parser, os.getcwd())
    cli_args = cli_parser.parse_args()

    distill_main(cli_args)