Unverified Commit d5d2744a authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Support T5 Distillation w/hidden state supervision (#7599)

parent 818c294f
...@@ -28,7 +28,7 @@ from lightning_base import generic_train # noqa ...@@ -28,7 +28,7 @@ from lightning_base import generic_train # noqa
class BartSummarizationDistiller(SummarizationModule): class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart.""" """Supports Bart, Pegasus and other models that inherit from Bart."""
loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"] loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
def __init__(self, hparams): def __init__(self, hparams):
assert Path(hparams.data_dir).exists() assert Path(hparams.data_dir).exists()
...@@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -46,9 +46,19 @@ class BartSummarizationDistiller(SummarizationModule):
if hparams.length_penalty != -1: if hparams.length_penalty != -1:
student.config.length_penalty = hparams.length_penalty student.config.length_penalty = hparams.length_penalty
super().__init__(hparams, model=student, config=student.config) super().__init__(hparams, model=student, config=student.config)
model_type = student.config.model_type
self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int] self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids # type: List[int], List[int]
self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers if model_type == "t5":
teacher_encoder_layers = len(teacher.get_encoder().block)
teacher_decoder_layers = len(teacher.get_decoder().block)
else:
teacher_encoder_layers = teacher.config.encoder_layers
teacher_decoder_layers = teacher.config.decoder_layers
self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
self.teacher = teacher self.teacher = teacher
freeze_params(self.teacher) freeze_params(self.teacher)
...@@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -59,17 +69,17 @@ class BartSummarizationDistiller(SummarizationModule):
del self.teacher.encoder del self.teacher.encoder
# Intermediate supervision: Decide which layers to supervise # Intermediate supervision: Decide which layers to supervise
if hparams.supervise_forward: if hparams.supervise_forward:
self.d_matches = get_layers_to_supervise( self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
) else: # student layer should emulate hidden states of the teacher layer it was copied from
else: self.e_matches = self.e_layer_ids
self.d_matches = self.d_layer_ids self.d_matches = self.d_layer_ids
self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean") self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
self.temperature = 2.0 self.temperature = 2.0
self.alpha_mlm = hparams.alpha_mlm self.alpha_mlm = hparams.alpha_mlm
self.alpha_ce = hparams.alpha_ce self.alpha_ce = hparams.alpha_ce
self.alpha_hid = hparams.alpha_hid self.alpha_hid = hparams.alpha_hid
self.alpha_encoder_loss = hparams.alpha_encoder_loss
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -129,7 +139,7 @@ class BartSummarizationDistiller(SummarizationModule):
output_hidden_states=True, output_hidden_states=True,
output_attentions=False, output_attentions=False,
use_cache=False, use_cache=False,
) # TODO(@sshleifer): return_dict=True cleanup )
# Same cross entropy vs. label smoothing logic as finetune.py # Same cross entropy vs. label smoothing logic as finetune.py
assert lm_logits.shape[-1] == self.model.config.vocab_size assert lm_logits.shape[-1] == self.model.config.vocab_size
...@@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -146,30 +156,32 @@ class BartSummarizationDistiller(SummarizationModule):
def zero_tensor(): def zero_tensor():
return torch.tensor(0.0).type_as(student_lm_loss) return torch.tensor(0.0).type_as(student_lm_loss)
loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor() hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
if self.different_encoder: if self.different_encoder: # compute encoder hidden state loss
with torch.no_grad(): with torch.no_grad():
teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()( teacher_enc_hid = self.teacher.get_encoder()(
input_ids, attention_mask=src_mask, output_hidden_states=True input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
) ).hidden_states
# DEPRECATE THIS
if self.hparams.alpha_encoder_loss > 0: hid_loss_enc = self.calc_hidden_loss(
loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask) src_mask,
enc_hidden_state,
hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids) teacher_enc_hid,
self.e_matches,
teacher_enc_outputs = (enc_outputs,) normalize_hidden=self.hparams.normalize_hidden,
assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs) )
with torch.no_grad(): with torch.no_grad():
tloss, tlogits, tdec_hidden, _ = self.teacher( outputs = self.teacher(
input_ids, input_ids,
attention_mask=src_mask, attention_mask=src_mask,
encoder_outputs=teacher_enc_outputs, encoder_outputs=(enc_outputs,),
decoder_input_ids=decoder_input_ids, decoder_input_ids=decoder_input_ids,
lm_labels=labels, lm_labels=labels,
output_hidden_states=True, output_hidden_states=True,
return_dict=True,
) )
tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
dec_mask = decoder_input_ids.ne(pad_token_id) dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits) loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states if self.alpha_hid > 0: # Intermediate supervision of decoder hidden states
...@@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule): ...@@ -180,10 +192,9 @@ class BartSummarizationDistiller(SummarizationModule):
blended_loss = ( blended_loss = (
self.alpha_ce * loss_ce self.alpha_ce * loss_ce
+ self.alpha_mlm * student_lm_loss + self.alpha_mlm * student_lm_loss
+ self.hparams.alpha_encoder_loss * loss_encoder
+ self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec) + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
) )
return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
@staticmethod @staticmethod
def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden): def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
...@@ -207,7 +218,6 @@ def add_distill_args(parser): ...@@ -207,7 +218,6 @@ def add_distill_args(parser):
parser.add_argument("--teacher", type=str) parser.add_argument("--teacher", type=str)
parser.add_argument("--alpha_ce", default=0.8, type=float) parser.add_argument("--alpha_ce", default=0.8, type=float)
parser.add_argument("--alpha_mlm", default=0.2, type=float) parser.add_argument("--alpha_mlm", default=0.2, type=float)
parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
parser.add_argument("--alpha_hid", default=0.0, type=float, required=False) parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
parser.add_argument("--student_decoder_layers", default=12, type=int, required=False) parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
parser.add_argument("--student_encoder_layers", default=12, type=int, required=False) parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)
......
...@@ -86,7 +86,6 @@ CHEAP_ARGS = { ...@@ -86,7 +86,6 @@ CHEAP_ARGS = {
"n_val": -1, "n_val": -1,
"n_test": -1, "n_test": -1,
"student_encoder_layers": 1, "student_encoder_layers": 1,
"alpha_encoder_loss": 0.0,
"freeze_encoder": False, "freeze_encoder": False,
"auto_scale_batch_size": False, "auto_scale_batch_size": False,
} }
...@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase): ...@@ -230,7 +229,6 @@ class TestSummarizationDistiller(unittest.TestCase):
evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp())) evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
@unittest.skip("T5 distillation is broken at the moment")
def test_distill_t5(self): def test_distill_t5(self):
updates = dict( updates = dict(
student_encoder_layers=1, student_encoder_layers=1,
...@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase): ...@@ -255,7 +253,6 @@ class TestSummarizationDistiller(unittest.TestCase):
model_name_or_path="sshleifer/tinier_bart", model_name_or_path="sshleifer/tinier_bart",
teacher=CHEAP_ARGS["model_name_or_path"], teacher=CHEAP_ARGS["model_name_or_path"],
val_check_interval=0.5, val_check_interval=0.5,
alpha_encoder_loss=0.4,
) )
default_updates.update(updates) default_updates.update(updates)
args_d: dict = CHEAP_ARGS.copy() args_d: dict = CHEAP_ARGS.copy()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment