Unverified Commit cdf1b7ae authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

fix to adjust for #8530 changes (#8612)

parent 2819da02
...@@ -148,7 +148,7 @@ class SummarizationModule(BaseTransformer): ...@@ -148,7 +148,7 @@ class SummarizationModule(BaseTransformer):
self.save_readable_batch(batch) self.save_readable_batch(batch)
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False) outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
lm_logits = outputs[0] lm_logits = outputs["logits"]
if self.hparams.label_smoothing == 0: if self.hparams.label_smoothing == 0:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id # Same behavior as modeling_bart.py, besides ignoring pad_token_id
ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment