Unverified Commit 6890d196 authored by seungeunrho, committed by GitHub

Shifting labels for causal LM when using label smoother (#17987)



* Shifting labels for causal LM when using label smoother

When training a causal LM, the loss is computed inside the model's forward() function
and the labels are shifted internally. If label smoothing is applied, however, the loss
is computed in the trainer's compute_loss function, where the labels are not shifted.
This misaligns the labels with their corresponding inputs; this commit resolves that
misalignment.
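As a rough illustration of why the shift matters (a minimal sketch with made-up tensor shapes, not code from this commit): in a causal LM, the logits at position i predict the token at position i+1, so the loss must compare shifted views of the logits and labels.

```python
import torch

# Minimal sketch of causal-LM label shifting (hypothetical shapes:
# batch 1, sequence length 4, vocab size 10; not code from this commit).
logits = torch.randn(1, 4, 10)           # predictions at positions 0..3
labels = torch.tensor([[5, 2, 7, 1]])    # target token ids at positions 0..3

# Logits at position i predict the token at position i+1, so drop the last
# logit position and the first label position before computing the loss.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = torch.nn.functional.cross_entropy(
    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
```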

Resolves #17960

On branch shift_labels_for_causalLM
Changes to be committed:
	modified:   src/transformers/trainer.py
	modified:   src/transformers/trainer_pt_utils.py

* Update trainer.py

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6f0723a9
src/transformers/trainer.py
@@ -69,6 +69,7 @@ from .deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
 from .dependency_versions_check import dep_version_check
 from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
@@ -2384,7 +2385,10 @@ class Trainer:
             self._past = outputs[self.args.past_index]

         if labels is not None:
-            loss = self.label_smoother(outputs, labels)
+            if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+                loss = self.label_smoother(outputs, labels, shift_labels=True)
+            else:
+                loss = self.label_smoother(outputs, labels)
         else:
             # We don't use .loss here since the model may return tuples instead of ModelOutput.
             loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
...
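A quick sketch of what the new guard checks: MODEL_FOR_CAUSAL_LM_MAPPING_NAMES maps model types to causal-LM class names, and _get_name() returns the class name of the (unwrapped) model, so the shifted path is taken only for registered causal LMs. The stand-in class below is hypothetical.

```python
# Sketch of the guard's behavior; FakeGPT2LMHeadModel is a hypothetical
# stand-in whose _get_name() mimics torch.nn.Module._get_name().
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

class FakeGPT2LMHeadModel:
    def _get_name(self):
        return "GPT2LMHeadModel"

model = FakeGPT2LMHeadModel()
# True for class names registered as causal LMs, e.g. "GPT2LMHeadModel".
print(model._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values())
```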
src/transformers/trainer_pt_utils.py
@@ -466,8 +466,12 @@ class LabelSmoother:
     epsilon: float = 0.1
     ignore_index: int = -100

-    def __call__(self, model_output, labels):
+    def __call__(self, model_output, labels, shift_labels=False):
         logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+        if shift_labels:
+            logits = logits[..., :-1, :].contiguous()
+            labels = labels[..., 1:].contiguous()
+
         log_probs = -nn.functional.log_softmax(logits, dim=-1)
         if labels.dim() == log_probs.dim() - 1:
             labels = labels.unsqueeze(-1)
...
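And a small usage sketch of the updated LabelSmoother (made-up tensors; assumes the change above is applied):

```python
import torch
from transformers.trainer_pt_utils import LabelSmoother

smoother = LabelSmoother(epsilon=0.1)

# Made-up causal-LM style output: logits for 4 positions over a 10-token vocab.
logits = torch.randn(1, 4, 10)
labels = torch.tensor([[5, 2, 7, 1]])

# With shift_labels=True the smoother drops the last logit position and the
# first label position internally, matching the shift the model's own
# forward() would apply before computing its loss.
loss = smoother({"logits": logits}, labels, shift_labels=True)
print(loss)
```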