".github/git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "edfce78c86c7615039dd8d74e89957a0eebbdc03"
Unverified commit eab5f596, authored by Suraj Patil, committed by GitHub

[s2s] add create student script (#7290)


Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent e50a931c
@@ -369,7 +369,7 @@ runtime: 6H on NVIDIA RTX 24GB GPU
If you are using `wandb` and comparing the two distillation methods, using this entry point will make your logs consistent,
because you will have the same hyperparameters logged in every run.
#### With a teacher (Intermediate Supervision)
*Note* only BART variants are supported
In this method, we try to enforce that the student and teacher produce similar encoder_outputs, logits, and hidden_states using `BartSummarizationDistiller`.
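As a rough sketch of the objective (simplified pseudocode, not the exact `BartSummarizationDistiller` implementation; the default weights below are placeholders, while `alpha_ce`, `alpha_mlm`, and `alpha_hid` correspond to the distiller's hyperparameters):
```python
import torch.nn.functional as F

def blended_distillation_loss(s_logits, t_logits, s_hid, t_hid, lm_loss,
                              alpha_ce=0.8, alpha_mlm=0.2, alpha_hid=3.0, temperature=2.0):
    """Illustrative only: shapes and default weights are placeholders."""
    # Soft-target term: KL between temperature-softened teacher and student token distributions.
    loss_ce = F.kl_div(
        F.log_softmax(s_logits / temperature, dim=-1),
        F.softmax(t_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    # Intermediate supervision: MSE between matched student and teacher hidden states.
    loss_hid = F.mse_loss(s_hid, t_hid)
    # Hard-target term: the student's own (possibly label-smoothed) LM loss, passed in as lm_loss.
    return alpha_ce * loss_ce + alpha_mlm * lm_loss + alpha_hid * loss_hid
```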
@@ -378,7 +378,7 @@ This is how `sshleifer/distilbart-xsum*` checkpoints were produced.
The command that produced `sshleifer/distilbart-xsum-12-6` is:
```bash
./train_distilbart_xsum.sh --logger_name wandb --gpus 1
```
runtime: 13H on V-100 16GB GPU.
...
@@ -4,7 +4,6 @@ import argparse
import gc
import os
import sys
import warnings
from pathlib import Path
from typing import List
@@ -15,18 +14,10 @@ from torch.nn import functional as F
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import calculate_bleu, freeze_params, label_smoothed_nll_loss, pickle_load, use_task_specific_params
# need the parent dir module
@@ -41,87 +32,50 @@ class BartSummarizationDistiller(SummarizationModule):
def __init__(self, hparams):
assert Path(hparams.data_dir).exists()
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
save_dir = self.output_dir.joinpath("student")
hparams.model_name_or_path = str(save_dir)  # Tell lightning we are training the student
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
use_task_specific_params(teacher, hparams.task) # We copy good generation parameters to student by default
student, e_layer_ids, d_layer_ids = create_student_by_copying_alternating_layers(
teacher, e=hparams.student_encoder_layers, d=hparams.student_decoder_layers, save_path=save_dir
)
if hparams.length_penalty != -1:
student.config.length_penalty = hparams.length_penalty
super().__init__(hparams, model=student, config=student.config)
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
self.teacher = teacher
use_task_specific_params(self.teacher, "summarization")
freeze_params(self.teacher)
self.sanity_check_gradients()
if not self.different_encoder: # To save RAM, delete teacher encoder and freeze student encoder.
try:
del self.teacher.model.encoder
except AttributeError: # T5
del self.teacher.encoder
# Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward:
self.d_matches = get_layers_to_supervise(
n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
)
else:
self.d_matches = self.d_layer_ids
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0
self.alpha_mlm = hparams.alpha_mlm
self.alpha_ce = hparams.alpha_ce
self.alpha_hid = hparams.alpha_hid
self.alpha_encoder_loss = hparams.alpha_encoder_loss
gc.collect()
torch.cuda.empty_cache()
def sanity_check_gradients(self):
assert_all_frozen(self.teacher)
assert_all_frozen(self.model.model.decoder.embed_tokens)
assert_all_frozen(self.model.model.encoder.embed_tokens)
if self.different_encoder:
assert any_requires_grad(self.model.model.encoder)
else:
freeze_params(self.model.model.encoder)
del self.teacher.model.encoder
def pre_init(self, hparams):
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
student_updates = {
"decoder_layers": hparams.student_decoder_layers,
"encoder_layers": hparams.student_encoder_layers,
}
if hparams.length_penalty != -1:
student_updates["length_penalty"] = hparams.length_penalty
e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
hparams.e_layer_to_copy = e_layers_to_copy
d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
if hparams.supervise_forward:
hparams.d_matches = get_layers_to_supervise(
student_updates["decoder_layers"], teacher.config.decoder_layers
)
else:
hparams.d_matches = d_layers_to_copy
hparams.d_layer_to_copy = d_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
student_cfg = teacher.config_class(**kw)
student = type(teacher)(student_cfg)
student, _ = init_student(student, teacher)
save_dir = self.output_dir.joinpath("student")
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
student.save_pretrained(save_dir)
hparams.model_name_or_path = str(save_dir)
return student, student_cfg, teacher
def copy_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
if teacher.config.model_type == "t5":
return self.copy_t5_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
if self.different_decoder:
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
if self.different_encoder:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
def copy_t5_to_student(self, d_layers_to_copy, e_layers_to_copy, hparams, student, teacher):
self.different_encoder: bool = hparams.student_encoder_layers != teacher.config.num_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.num_layers
if self.different_decoder:
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
if self.different_encoder:
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
"""Supervise MSE(teacher.encoder_outputs, student.encoder_outputs)."""
# raise NotImplementedError()
if mask is not None:
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(student_outputs).bool()
@@ -133,20 +87,15 @@ class BartSummarizationDistiller(SummarizationModule):
return F.mse_loss(s_logits_slct, t_logits_slct)
def calc_ce_loss(self, mask, s_logits, t_logits):
"""Copy pasted from distilbert (transformers/examples/distillation/)"""
# mask has False at padding_idx
sel_mask = mask[:, :, None].expand_as(s_logits)
vocab_size = s_logits.size(-1)
s_logits_slct = torch.masked_select(s_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
t_logits_slct = torch.masked_select(t_logits, sel_mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
s_logits_slct = s_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
t_logits_slct = t_logits_slct.view(-1, vocab_size)  # (bs * seq_length, voc_size) modulo the 1s in mask
assert t_logits_slct.size() == s_logits_slct.size()
loss_ce = (
self.ce_loss_fct(
@@ -155,7 +104,7 @@ class BartSummarizationDistiller(SummarizationModule):
)
* (self.temperature) ** 2
)
return loss_ce
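# Roughly, with self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") and T = self.temperature,
# the (partially elided) expression above computes
#     KLDiv(log_softmax(s_logits_slct / T), softmax(t_logits_slct / T)) * T**2;
# the T**2 factor keeps the soft-target gradients on the same scale as the hard-target LM loss.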
@staticmethod
def add_model_specific_args(parser, root_dir):
@@ -164,10 +113,14 @@ class BartSummarizationDistiller(SummarizationModule):
return parser
def _step(self, batch):
# assert is_frozen(self.teacher) copied_decoder_layers
pad_token_id = self.tokenizer.pad_token_id
input_ids, src_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
if isinstance(self.model, T5ForConditionalGeneration):
decoder_input_ids = self.model._shift_right(labels)
else:
decoder_input_ids = shift_tokens_right(labels, pad_token_id)
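# Both branches build decoder_input_ids for teacher forcing by shifting the labels one position
# to the right; T5 uses its own _shift_right helper, BART uses shift_tokens_right from modeling_bart.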
# noinspection PyCallingNonCallable
lm_logits, dec_hidden, enc_outputs, enc_hidden_state = self(
input_ids,
@@ -183,11 +136,11 @@ class BartSummarizationDistiller(SummarizationModule):
if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
student_lm_loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
else:
lprobs = torch.nn.functional.log_softmax(lm_logits, dim=-1)
student_lm_loss, _ = label_smoothed_nll_loss(
lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
)
def zero_tensor():
@@ -196,15 +149,14 @@ class BartSummarizationDistiller(SummarizationModule):
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
input_ids, attention_mask=src_mask, output_hidden_states=True
)
# DEPRECATE THIS
if self.hparams.alpha_encoder_loss > 0:
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
teacher_enc_outputs = (enc_outputs,)
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
@@ -215,13 +167,15 @@ class BartSummarizationDistiller(SummarizationModule):
attention_mask=src_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
lm_labels=labels,
output_hidden_states=True,
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
hid_loss_dec = self.calc_hidden_loss(
dec_mask, dec_hidden, tdec_hidden, self.d_matches, normalize_hidden=self.hparams.normalize_hidden
)
blended_loss = (
self.alpha_ce * loss_ce
@@ -231,7 +185,9 @@ class BartSummarizationDistiller(SummarizationModule):
)
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
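# The full blended loss (its middle terms fall outside the hunk above) combines the pieces as
#     alpha_ce * loss_ce + alpha_mlm * student_lm_loss
#     + alpha_encoder_loss * loss_encoder + alpha_hid * (hid_loss_enc + hid_loss_dec),
# with each weight taken from the corresponding hparams value.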
@staticmethod
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
"""MSE(student_hid, teacher_hid[matches]). Called "Intermediate supervision" in paper. Inspired by TinyBERT."""
msg = "expected list or tuple for hidden_states, got tensor of shape: " msg = "expected list or tuple for hidden_states, got tensor of shape: "
assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}" assert not isinstance(hidden_states, torch.Tensor), f"{msg}{hidden_states.shape}"
assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}" assert not isinstance(hidden_states_T, torch.Tensor), f"{msg}{hidden_states_T.shape}"
...@@ -239,7 +195,7 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -239,7 +195,7 @@ class BartSummarizationDistiller(SummarizationModule):
valid_count = mask.sum() * hidden_states[0].size(-1) valid_count = mask.sum() * hidden_states[0].size(-1)
student_states = torch.stack([hidden_states[i] for i in range(len(matches))]) student_states = torch.stack([hidden_states[i] for i in range(len(matches))])
teacher_states = torch.stack([hidden_states_T[j] for j in matches]) teacher_states = torch.stack([hidden_states_T[j] for j in matches])
if self.hparams.normalize_hidden: if normalize_hidden:
student_states = F.layer_norm(student_states, student_states.shape[1:]) student_states = F.layer_norm(student_states, student_states.shape[1:])
teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:]) teacher_states = F.layer_norm(teacher_states, teacher_states.shape[1:])
mse = F.mse_loss(student_states, teacher_states, reduction="none") mse = F.mse_loss(student_states, teacher_states, reduction="none")
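# Shape note: each element of hidden_states is (batch, seq_len, hidden), so the stacked tensors are
# (n_matches, batch, seq_len, hidden); layer_norm normalizes over every dimension except the layer
# axis, and valid_count above gives the number of non-padding elements used to average the masked MSE.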
@@ -287,130 +243,9 @@ class BartTranslationDistiller(BartSummarizationDistiller):
return parser
class T5SummarizationDistiller(BartSummarizationDistiller):
def pre_init(self, hparams):
raise NotImplementedError("T5 Distillation does not work yet")
self.output_dir = Path(hparams.output_dir)
self.output_dir.mkdir(exist_ok=True)
teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
n_layer = hparams.student_decoder_layers
assert n_layer == hparams.student_encoder_layers # TODO(SS): relax this constraint so that we can do 12-6.
d_layers_to_copy = get_layers_to_copy(n_layer, len(teacher.decoder.block))
e_layers_to_copy: List = get_layers_to_copy(n_layer, len(teacher.encoder.block))
student_updates = {"num_layers": n_layer}
hparams.d_layer_to_copy = d_layers_to_copy
hparams.e_layer_to_copy = e_layers_to_copy
kw = teacher.config.to_diff_dict()
kw.update(student_updates)
# Copy weights
student_cfg = T5Config(**kw)
student = T5ForConditionalGeneration(student_cfg)
student, _ = init_student(student, teacher)
self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
Path(hparams.output_dir).mkdir(exist_ok=True)
task_specific_params = student.config.task_specific_params
if task_specific_params is not None:
student.config.update(task_specific_params.get("summarization", {})) # TODO: dont hardcode
save_dir = self.output_dir.joinpath("student")
save_dir.mkdir(exist_ok=True)
student.save_pretrained(save_dir)
hparams.model_name_or_path = str(save_dir)
return student, student_cfg, teacher
def freeze_embeds(self):
freeze_params(self.model.shared)
for d in [self.model.encoder, self.model.decoder]:
freeze_params(d.embed_tokens)
def sanity_check_gradients(self):
"""T5"""
assert_all_frozen(self.teacher)
assert_all_frozen(self.model.decoder.embed_tokens)
assert_all_frozen(self.model.encoder.embed_tokens)
if self.different_encoder:
assert any_requires_grad(self.model.encoder)
else:
freeze_params(self.model.encoder)
del self.teacher.model.encoder
if self.different_decoder:
assert any_requires_grad(self.model.decoder)
else:
freeze_params(self.model.decoder) # TODO(SS): very suspicious
def _step(self, batch):
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
decoder_input_ids = y[:, :-1].contiguous()
labels = y[:, 1:].clone()
labels[y[:, 1:] == pad_token_id] = -100
# noinspection PyCallingNonCallable
dec_mask = decoder_input_ids.ne(pad_token_id)
sloss, slogits, dec_hidden, enc_outputs, enc_hidden_state = self(
source_ids,
attention_mask=source_mask,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
output_attentions=False,
use_cache=False,
)
def zero_tensor():
return torch.tensor(0.0).type_as(sloss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
if self.different_encoder:
with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid = self.teacher.encoder(
source_ids,
attention_mask=source_mask,
output_hidden_states=True,
use_cache=False,
)
if self.hparams.alpha_encoder_loss > 0:
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, source_mask)
hid_loss_enc = self.calc_hidden_loss(
source_mask, enc_hidden_state, teacher_enc_hid, self.hparams.e_layer_to_copy
)
teacher_enc_outputs = (enc_outputs,)
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
with torch.no_grad():
tloss, tlogits, tdec_hidden, _ = self.teacher(
source_ids,
attention_mask=source_mask,
encoder_outputs=teacher_enc_outputs,
decoder_input_ids=decoder_input_ids,
labels=labels,
output_hidden_states=True,
use_cache=False,
)
loss_ce, s_logits_slct, t_logits_slct = self.calc_ce_loss(dec_mask, slogits, tlogits)
if self.alpha_hid > 0:
hid_loss_dec = self.calc_hidden_loss(dec_mask, dec_hidden, tdec_hidden, self.hparams.d_matches)
blended_loss = (
self.alpha_ce * loss_ce
+ self.alpha_mlm * sloss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
)
return blended_loss, loss_ce, sloss, loss_encoder, hid_loss_enc, hid_loss_dec
def create_module(args):
t5 = "t5" in args.model_name_or_path
if args.no_teacher:
module_cls = TranslationModule if "translation" in args.task else SummarizationModule
elif t5: # DISTILL T5 WITH TEACHER FOR SUMMARIZATION
assert "translation" not in args.task, "t5 translation distillation not supported"
module_cls = T5SummarizationDistiller
else:  # DISTILL WITH TEACHER
module_cls = BartTranslationDistiller if "translation" in args.task else BartSummarizationDistiller
args.setup_cls: str = module_cls.__name__
@@ -443,56 +278,6 @@ def evaluate_checkpoint(ckpt_path: Path, dest_dir=None):
trainer.test(model)
LAYERS_TO_COPY = {
# maps num layers in student -> which teacher layers to copy.
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
12: {
1: [0],
2: [0, 6],
3: [0, 6, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: list(range(12)),
},
16: { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 8],
3: [0, 8, 15],
4: [0, 5, 10, 15],
6: [0, 3, 6, 9, 12, 15],
8: [0, 2, 4, 6, 8, 10, 12, 15],
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
16: list(range(16)),
},
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}
LAYERS_TO_SUPERVISE = {
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
2: {1: [1], 2: [0, 1]},
}
def get_layers_to_supervise(n_student, n_teacher):
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
def get_layers_to_copy(n_student, n_teacher):
try:
val = LAYERS_TO_COPY[n_teacher][n_student]
assert len(LAYERS_TO_SUPERVISE[n_teacher][n_student]) == len(val) == n_student
return val
except KeyError:
if n_student != n_teacher:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
)
return list(range(n_student))
def distill_main(args):
Path(args.output_dir).mkdir(exist_ok=True)
if len(os.listdir(args.output_dir)) > 3 and args.do_train:
...
from typing import List
from torch import nn
def init_student(student, teacher):
teacher_state_dict = teacher.state_dict()
info = student.load_state_dict(teacher_state_dict, strict=False)
assert info.missing_keys == [], info.missing_keys
return student, info
def copy_decoder_layers(teacher, student, l2copy=[0, 2, 4, 7, 9, 11]):
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, l2copy)
def copy_layers(teacher_layers: nn.ModuleList, student_layers: nn.ModuleList, layers_to_copy: List) -> None:
layers_to_copy = nn.ModuleList([l for i, l in enumerate(teacher_layers) if i in layers_to_copy])
assert len(student_layers) == len(layers_to_copy), f"{len(student_layers)} != {len(layers_to_copy)}"
student_layers.load_state_dict(layers_to_copy.state_dict())
import warnings
from pathlib import Path
from typing import List, Tuple, Union
import fire
from torch import nn
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PreTrainedModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
def copy_layers(src_layers: nn.ModuleList, dest_layers: nn.ModuleList, layers_to_copy: List[int]) -> None:
layers_to_copy = nn.ModuleList([l for i, l in enumerate(src_layers) if i in layers_to_copy])
assert len(dest_layers) == len(layers_to_copy), f"{len(dest_layers)} != {len(layers_to_copy)}"
dest_layers.load_state_dict(layers_to_copy.state_dict())
LAYERS_TO_COPY = {
# maps num layers in teacher -> num_layers in student -> which teacher layers to copy.
# 12: bart, 16: pegasus, 6: marian/Helsinki-NLP
12: {
1: [0], # This says that if the teacher has 12 layers and the student has 1, copy layer 0 of the teacher
2: [0, 6],
3: [0, 6, 11],
4: [0, 4, 8, 11],
6: [0, 2, 4, 7, 9, 11],
9: [0, 1, 2, 4, 5, 7, 9, 10, 11],
12: list(range(12)),
},
16: { # maps num layers in student -> which teacher layers to copy
1: [0],
2: [0, 8],
3: [0, 8, 15],
4: [0, 5, 10, 15],
6: [0, 3, 6, 9, 12, 15],
8: [0, 2, 4, 6, 8, 10, 12, 15],
9: [0, 1, 3, 5, 7, 9, 11, 13, 15],
12: [0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15],
16: list(range(16)),
},
6: {1: [0], 2: [0, 5], 3: [0, 2, 5], 4: [0, 1, 3, 5], 6: list(range(6))},
}
LAYERS_TO_SUPERVISE = {
# maps num layers in student -> which teacher layers to copy.
6: {1: [5], 2: [3, 5], 3: [1, 4, 5], 4: [1, 2, 4, 5]},
12: {1: [11], 2: [5, 11], 3: [3, 7, 11], 6: [1, 3, 5, 8, 10, 11]},
16: {1: [15], 4: [4, 9, 12, 15], 8: [1, 3, 5, 7, 9, 11, 13, 15]},
}
def pick_layers_to_copy(n_student, n_teacher):
try:
val = LAYERS_TO_COPY[n_teacher][n_student]
return val
except KeyError:
if n_student != n_teacher:
warnings.warn(
f"no hardcoded layers to copy for teacher {n_teacher} -> student {n_student}, defaulting to first {n_student}"
)
return list(range(n_student))
def get_layers_to_supervise(n_student, n_teacher) -> List[int]:
"""Used or the --supervise_forward kwarg"""
if n_student > n_teacher:
raise ValueError(f"Cannot perform intermediate supervision for student {n_student} > teacher {n_teacher}")
elif n_teacher == n_student:
return list(range(n_teacher))
elif n_student == 1:
return [n_teacher - 1]
else:
return LAYERS_TO_SUPERVISE[n_teacher][n_student]
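# Worked examples, read straight off the tables and branches above:
#   pick_layers_to_copy(6, 12)      -> [0, 2, 4, 7, 9, 11]  # spread across the teacher, keeping first and last
#   pick_layers_to_copy(4, 16)      -> [0, 5, 10, 15]
#   get_layers_to_supervise(3, 12)  -> [3, 7, 11]           # teacher layers matched to the 3 student layers
#   get_layers_to_supervise(1, 12)  -> [11]                 # a lone student layer is matched to the last teacher layer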
def create_student_by_copying_alternating_layers(
teacher: Union[str, PreTrainedModel],
save_path: Union[str, Path] = "student",
e: Union[int, None] = None,
d: Union[int, None] = None,
copy_first_teacher_layers=False,
**extra_config_kwargs
) -> Tuple[PreTrainedModel, List[int], List[int]]:
"""Make a student by copying alternating layers from a teacher, save it to save_path.
Args:
teacher: str or PreTrainedModel. If str, AutoModelForSeq2SeqLM.from_pretrained(teacher) is called before
copying layers
save_path: where to save the student, defaults to the "student" directory.
e: how many Encoder layers the student should have; default is a full copy of the teacher
d: how many Decoder layers the student should have; default is a full copy of the teacher
copy_first_teacher_layers: [bool] don't copy alternating layers, just the first e/d.
**extra_config_kwargs: extra kwargs to pass to the student, by default the teacher config is used.
Returns:
student: new, smaller model. (Also saves it to save_path)
e_layers_to_copy: list of which teacher encoder layers were used
d_layers_to_copy: list of which teacher decoder layers were used
"""
_msg = "encoder_layers and decoder_layers cannot be both None-- you would just have an identical teacher."
assert (e is not None) or (d is not None), _msg
if isinstance(teacher, str):
AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience
teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval()
else:
assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
init_kwargs = teacher.config.to_diff_dict()
try:
teacher_e, teacher_d = teacher.config.encoder_layers, teacher.config.decoder_layers
if e is None:
e = teacher_e
if d is None:
d = teacher_d
init_kwargs.update({"encoder_layers": e, "decoder_layers": d})
except AttributeError: # T5
teacher_e, teacher_d = teacher.config.num_layers, teacher.config.num_hidden_layers
assert e == d, "T5 Students must be symmetric"
init_kwargs["num_layers"] = e
# Kwargs to instantiate student = teacher kwargs with updated layer numbers + **extra_config_kwargs
init_kwargs.update(extra_config_kwargs)
# Copy weights
student_cfg = teacher.config_class(**init_kwargs)
student = AutoModelForSeq2SeqLM.from_config(student_cfg)
# Start by copying the full teacher state dict; this copies the first N teacher layers to the student.
info = student.load_state_dict(teacher.state_dict(), strict=False)
assert info.missing_keys == [], info.missing_keys  # every student key should have a teacher key.
if copy_first_teacher_layers: # Our copying is done. We just log and save
e_layers_to_copy, d_layers_to_copy = list(range(e)), list(range(d))
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
)
student.save_pretrained(save_path)
return student, e_layers_to_copy, d_layers_to_copy
# Decide which layers of the teacher to copy. Not exactly alternating -- we try to keep first and last layer.
e_layers_to_copy: List[int] = pick_layers_to_copy(e, teacher_e)
d_layers_to_copy: List[int] = pick_layers_to_copy(d, teacher_d)
try:
copy_layers(teacher.model.encoder.layers, student.model.encoder.layers, e_layers_to_copy)
copy_layers(teacher.model.decoder.layers, student.model.decoder.layers, d_layers_to_copy)
except AttributeError: # For t5, student.model.encoder.layers is called student.encoder.block
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
copy_layers(teacher.decoder.block, student.decoder.block, d_layers_to_copy)
logger.info(
f"Copied encoder layers {e_layers_to_copy} and decoder layers {d_layers_to_copy}. Saving them to {save_path}"
)
student.config.init_metadata = dict(
teacher_type=teacher.config.model_type,
copied_encoder_layers=e_layers_to_copy,
copied_decoder_layers=d_layers_to_copy,
)
student.save_pretrained(save_path)
# Save information about copying for easier reproducibility
return student, e_layers_to_copy, d_layers_to_copy
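# Example of programmatic use (save_path is a placeholder; the teacher matches the training script in this commit):
#   student, e_ids, d_ids = create_student_by_copying_alternating_layers(
#       "facebook/bart-large-xsum", save_path="student_12_6", e=12, d=6
#   )
# With a 12-layer BART teacher and d=6, decoder layers [0, 2, 4, 7, 9, 11] are copied (see LAYERS_TO_COPY).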
if __name__ == "__main__":
fire.Fire(create_student_by_copying_alternating_layers)
#!/usr/bin/env python
import fire
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
def save_randomly_initialized_version(config_name: str, save_dir: str, **config_kwargs):
"""Save a randomly initialized version of a model using a pretrained config.
Args:
config_name: which config to use
save_dir: where to save the resulting model and tokenizer
config_kwargs: Passed to AutoConfig
Usage::
save_randomly_initialized_version("facebook/bart-large-cnn", "distilbart_random_cnn_6_3", encoder_layers=6, decoder_layers=3, num_beams=3)
"""
cfg = AutoConfig.from_pretrained(config_name, **config_kwargs)
model = AutoModelForSeq2SeqLM.from_config(cfg)
model.save_pretrained(save_dir)
AutoTokenizer.from_pretrained(config_name).save_pretrained(save_dir)
return model
if __name__ == "__main__":
fire.Fire(save_randomly_initialized_version)
import tempfile
import unittest
from make_student import create_student_by_copying_alternating_layers
from transformers import AutoConfig
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch
TINY_BART = "sshleifer/bart-tiny-random"
TINY_T5 = "patrickvonplaten/t5-tiny-random"
@require_torch
class MakeStudentTester(unittest.TestCase):
@cached_property
def teacher_config(self):
return AutoConfig.from_pretrained(TINY_BART)
def test_valid_t5(self):
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=1)
self.assertEqual(student.config.num_hidden_layers, 1)
def test_invalid_t5(self):
# T5 students must have the same e==d because there is only one config property
with self.assertRaises(AssertionError):
student, *_ = create_student_by_copying_alternating_layers(TINY_T5, tempfile.mkdtemp(), e=1, d=None)
def test_same_decoder_small_encoder(self):
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=None)
self.assertEqual(student.config.encoder_layers, 1)
self.assertEqual(student.config.decoder_layers, self.teacher_config.encoder_layers)
def test_small_enc_small_dec(self):
student, *_ = create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=1, d=1)
self.assertEqual(student.config.encoder_layers, 1)
self.assertEqual(student.config.decoder_layers, 1)
def test_raises_assert(self):
with self.assertRaises(AssertionError):
create_student_by_copying_alternating_layers(TINY_BART, tempfile.mkdtemp(), e=None, d=None)
#!/usr/bin/env bash
export PYTHONPATH="../":"${PYTHONPATH}"
export BS=16
export GAS=2
python distillation.py \
--teacher facebook/bart-large-xsum --data_dir xsum \
--student_decoder_layers 6 --student_encoder_layers 12 \
--freeze_encoder --freeze_embeds \
--learning_rate=3e-4 \
--do_train \
--do_predict \
--fp16 --fp16_opt_level=O1 \
--val_check_interval 0.1 --n_val 1000 --eval_beams 2 --length_penalty=0.5 \
--teacher facebook/bart-large-xsum --data_dir $XSUM_DIR \
--max_target_length=60 --val_max_target_length=60 --test_max_target_length=100 \
--student_decoder_layers 6 --student_encoder_layers 12 \
--freeze_encoder --freeze_embeds \
--model_name_or_path IGNORED \
--alpha_hid=3. \
--train_batch_size=16 --eval_batch_size=16 --gradient_accumulation_steps=2 \
--sortish_sampler \
--num_train_epochs=6 \
--warmup_steps 500 \
--output_dir distilbart_xsum_12_6 \
"$@"