Unverified Commit 9f1747f9 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Seq2Seq] Correct import in Seq2Seq Trainer (#8254)

parent 504ff7bb
......@@ -62,10 +62,7 @@ class Seq2SeqTrainer(Trainer):
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
else:
# dynamically import label_smoothed_nll_loss
try:
from .utils import label_smoothed_nll_loss
except ImportError:
from utils import label_smoothed_nll_loss
from utils import label_smoothed_nll_loss
self.loss_fn = label_smoothed_nll_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment