Unverified commit 9bd30f7c authored by Patrick von Platen, committed by GitHub

[Seq2SeqTrainer] Move import to init to make file self-contained (#8194)

* boom boom

* reverse order
parent 1f12934d
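
In short: instead of importing label_smoothed_nll_loss at module level, the trainer now decides once in __init__ which loss callable to use and stores it as self.loss_fn; the compute-loss paths below then just call that attribute. A minimal sketch of the pattern (class and method names here are illustrative, not the actual Seq2SeqTrainer code):

```python
import torch

class LossBindingSketch:
    """Illustrative only: pick the loss callable once in __init__, reuse it as self.loss_fn."""

    def __init__(self, pad_token_id: int, label_smoothing: float = 0.0):
        if label_smoothing == 0:
            # plain token-level cross-entropy that skips padding positions
            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)
        else:
            # the actual diff lazily imports label_smoothed_nll_loss at this point;
            # a stand-in is bound here so the sketch stays self-contained
            self.loss_fn = self._label_smoothing_stand_in

    @staticmethod
    def _label_smoothing_stand_in(*args, **kwargs):
        raise NotImplementedError("see the label_smoothed_nll_loss sketch at the end of the diff")
```

Every call site stays agnostic about which loss is in use, and the smoothing helper is only imported for runs that actually enable label smoothing.
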
@@ -20,12 +20,6 @@ from transformers.optimization import (
 from transformers.trainer_pt_utils import get_tpu_sampler
 
-try:
-    from .utils import label_smoothed_nll_loss
-except ImportError:
-    from utils import label_smoothed_nll_loss
-
-
 logger = logging.get_logger(__name__)
 
 arg_to_scheduler = {
@@ -64,6 +58,17 @@ class Seq2SeqTrainer(Trainer):
                 f"The `config.pad_token_id` is `None`. Using `config.eos_token_id` = {self.config.eos_token_id} for padding.."
             )
 
+        if self.args.label_smoothing == 0:
+            self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
+        else:
+            # dynamically import label_smoothed_nll_loss
+            try:
+                from .utils import label_smoothed_nll_loss
+            except ImportError:
+                from utils import label_smoothed_nll_loss
+
+            self.loss_fn = label_smoothed_nll_loss
+
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
         Setup the optimizer and the learning rate scheduler.
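
The try/except around the import mirrors a common pattern for example scripts that may run either as part of a package (relative import) or as a plain script (absolute import). A self-contained sketch of that fallback, with module names as placeholders and a local stand-in as the last resort so the snippet runs anywhere:

```python
def _local_stand_in(*args, **kwargs):
    # placeholder used only when no utils module is importable in this sketch
    raise NotImplementedError("no label_smoothed_nll_loss available")

try:
    # works when the file is imported as a module inside a package
    from .utils import label_smoothed_nll_loss
except ImportError:
    try:
        # works when the file is executed as a standalone script next to utils.py
        from utils import label_smoothed_nll_loss
    except ImportError:
        label_smoothed_nll_loss = _local_stand_in
```
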
@@ -135,9 +140,7 @@ class Seq2SeqTrainer(Trainer):
         if self.data_args is not None and self.data_args.ignore_pad_token_for_loss:
             # force training to ignore pad token
             logits = model(**inputs, use_cache=False)[0]
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
-            loss = loss_fct(logits.view(-1, logits.shape[-1]), labels.view(-1))
+            loss = self.loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
         else:
             # compute usual loss via models
             loss, logits = model(**inputs, labels=labels, use_cache=False)[:2]
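
For the unsmoothed branch, the flattening plus ignore_index combination works as follows (toy shapes and a made-up pad id of 0, just to show the mechanics):

```python
import torch

pad_token_id = 0
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

logits = torch.randn(2, 4, 10)                 # (batch=2, seq_len=4, vocab=10)
labels = torch.tensor([[5, 3, 0, 0],
                       [7, 2, 9, 0]])          # 0 marks padding positions

# flatten to (batch * seq_len, vocab) and (batch * seq_len,), as in the diff;
# positions whose label equals pad_token_id contribute nothing to the loss
loss = loss_fn(logits.view(-1, logits.shape[-1]), labels.view(-1))
print(loss.item())
```
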
@@ -145,9 +148,7 @@ class Seq2SeqTrainer(Trainer):
             # compute label smoothed loss
             logits = model(**inputs, use_cache=False)[0]
             lprobs = torch.nn.functional.log_softmax(logits, dim=-1)
-            loss, _ = label_smoothed_nll_loss(
-                lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id
-            )
+            loss, _ = self.loss_fn(lprobs, labels, self.args.label_smoothing, ignore_index=self.config.pad_token_id)
         return loss, logits
 
     def compute_loss(self, model, inputs):
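
For reference, a fairseq-style label_smoothed_nll_loss usually looks something like the sketch below; the signature matches the call above, but the exact code in the example's utils module may differ in details:

```python
import torch

def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
    """Sketch of a label-smoothed NLL loss over log-probabilities (assumed formulation)."""
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)       # NLL of the gold token
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)       # uniform mass over the vocabulary
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    nll_loss = nll_loss.sum()
    smooth_loss = smooth_loss.sum()
    eps_i = epsilon / lprobs.size(-1)
    # blend the gold-token NLL with a uniform prior over all classes
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss
```

Returning both the smoothed loss and the raw NLL lets callers log the plain cross-entropy while optimizing the smoothed objective, which is why the call site above unpacks two values.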