Unverified Commit d86d57fa authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[s2s] distillation apex breaks return_dict obj (#8631)

* apex breaks return_dict obj

* style
parent bf3611b2
......@@ -154,7 +154,7 @@ class SummarizationDistiller(SummarizationModule):
output_attentions=False,
use_cache=False,
)
lm_logits = student_outputs.logits
lm_logits = student_outputs["logits"]
# Same cross entropy vs. label smoothing logic as finetune.py
assert lm_logits.shape[-1] == self.model.config.vocab_size
......@@ -171,7 +171,9 @@ class SummarizationDistiller(SummarizationModule):
def zero_tensor():
return torch.tensor(0.0).type_as(student_lm_loss)
teacher_enc_outputs = student_outputs.encoder_last_hidden_state # use this unless self.different_base_models
teacher_enc_outputs = student_outputs[
"encoder_last_hidden_state"
] # use this unless self.different_base_models
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
if self.different_encoder: # compute encoder hidden state loss
all_teacher_encoder_outputs = self.teacher.get_encoder()(
......@@ -180,12 +182,12 @@ class SummarizationDistiller(SummarizationModule):
output_hidden_states=self.do_calc_hidden_loss,
)
if self.different_base_models:
teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
elif self.do_calc_hidden_loss:
hid_loss_enc = self.calc_hidden_loss(
src_mask,
student_outputs.encoder_hidden_states,
all_teacher_encoder_outputs.hidden_states,
student_outputs["encoder_hidden_states"],
all_teacher_encoder_outputs["hidden_states"],
self.e_matches,
normalize_hidden=self.hparams.normalize_hidden,
)
......@@ -199,12 +201,12 @@ class SummarizationDistiller(SummarizationModule):
use_cache=False, # since we are not passing labels, never let this default to True
)
dec_mask = decoder_input_ids.ne(pad_token_id)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
hid_loss_dec = self.calc_hidden_loss(
dec_mask,
student_outputs.decoder_hidden_states,
teacher_outputs.decoder_hidden_states,
student_outputs["decoder_hidden_states"],
teacher_outputs["decoder_hidden_states"],
self.d_matches,
normalize_hidden=self.hparams.normalize_hidden,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment