Unverified Commit 0d0d7769 authored by Yih-Dar, committed by GitHub

Allow trainer to return eval. loss for CLIP-like models (#20214)



* Allow trainer to return loss for CLIP-like models

* Apply suggestions

* update

* update

* update
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 822ae69c
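Editor's note: in practice, the change below lets `Trainer.evaluate` report an `eval_loss` for a model such as `CLIPModel`, which defines no label columns but can compute a contrastive loss when asked. A minimal sketch of that usage, modeled on the collator in the `contrastive-image-text` example; the `TinyPairs` dataset, token ids, and argument values are illustrative assumptions, not part of the commit:

```python
import torch
from torch.utils.data import Dataset
from transformers import CLIPModel, Trainer, TrainingArguments


class TinyPairs(Dataset):
    """Hypothetical stand-in for a preprocessed image-text eval set."""

    def __len__(self):
        return 8

    def __getitem__(self, i):
        return {
            "input_ids": torch.tensor([49406, 320, 1125, 49407]),  # dummy CLIP token ids
            "attention_mask": torch.ones(4, dtype=torch.long),
            "pixel_values": torch.rand(3, 224, 224),
        }


def collate_fn(examples):
    # `return_loss=True` in the batch tells `prediction_step` to treat the
    # model's contrastive loss as the evaluation loss, even with no labels.
    batch = {k: torch.stack([e[k] for e in examples]) for k in examples[0]}
    batch["return_loss"] = True
    return batch


model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="clip-eval"),
    data_collator=collate_fn,
    eval_dataset=TinyPairs(),
)
print(trainer.evaluate()["eval_loss"])  # reported thanks to this change
```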
src/transformers/trainer.py
@@ -134,6 +134,7 @@ from .utils import (
     CONFIG_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    can_return_loss,
     find_labels,
     get_full_repo_name,
     is_apex_available,
@@ -625,6 +626,7 @@ class Trainer:
         self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+        self.can_return_loss = can_return_loss(self.model.__class__)
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

         # Internal variables to keep track of the original batch size
@@ -3190,6 +3192,14 @@ class Trainer:
             logits and labels (each being optional).
         """
         has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        # For CLIP-like models capable of returning loss values.
+        # If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of
+        # `return_loss` is `True` in `model.forward`.
+        return_loss = inputs.get("return_loss", None)
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
         inputs = self._prepare_inputs(inputs)
         if ignore_keys is None:
             if hasattr(self.model, "config"):
@@ -3198,7 +3208,7 @@ class Trainer:
                 ignore_keys = []

         # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
-        if has_labels:
+        if has_labels or loss_without_labels:
             labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
             if len(labels) == 1:
                 labels = labels[0]
@@ -3208,7 +3218,7 @@ class Trainer:
         with torch.no_grad():
             if is_sagemaker_mp_enabled():
                 raw_outputs = smp_forward_only(model, inputs)
-                if has_labels:
+                if has_labels or loss_without_labels:
                     if isinstance(raw_outputs, dict):
                         loss_mb = raw_outputs["loss"]
                         logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
@@ -3226,7 +3236,7 @@ class Trainer:
                         logits_mb = raw_outputs
                     logits = smp_nested_concat(logits_mb)
             else:
-                if has_labels:
+                if has_labels or loss_without_labels:
                     with self.compute_loss_context_manager():
                         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()
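To make the control flow above concrete, here is a minimal standalone sketch (not Trainer code) of how `prediction_step` now decides to collect a loss when no label columns exist; the helper name `wants_loss_without_labels` is hypothetical:

```python
def wants_loss_without_labels(inputs: dict, label_names: list, can_return_loss: bool) -> bool:
    # An explicit `return_loss` key in the batch wins; otherwise fall back to
    # whether the model's forward signature defaults `return_loss` to True.
    return_loss = inputs.get("return_loss")
    if return_loss is None:
        return_loss = can_return_loss
    return len(label_names) == 0 and bool(return_loss)


# A CLIP-style eval batch that asks for the contrastive loss explicitly:
print(wants_loss_without_labels({"return_loss": True}, [], False))    # True
# A model whose forward defaults `return_loss=True` needs no batch key:
print(wants_loss_without_labels({}, [], True))                        # True
# Classification-style setups with label columns keep the old path:
print(wants_loss_without_labels({"labels": 0}, ["labels"], False))    # False
```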
src/transformers/utils/__init__.py
@@ -38,6 +38,7 @@ from .generic import (
     PaddingStrategy,
     TensorType,
     cached_property,
+    can_return_loss,
     expand_dims,
     find_labels,
     flatten_dict,
src/transformers/utils/generic.py
@@ -336,6 +336,28 @@ class ContextManagers:
         self.stack.__exit__(*args, **kwargs)


+def can_return_loss(model_class):
+    """
+    Check if a given model can return loss.
+
+    Args:
+        model_class (`type`): The class of the model.
+    """
+    model_name = model_class.__name__
+    if model_name.startswith("TF"):
+        signature = inspect.signature(model_class.call)  # TensorFlow models
+    elif model_name.startswith("Flax"):
+        signature = inspect.signature(model_class.__call__)  # Flax models
+    else:
+        signature = inspect.signature(model_class.forward)  # PyTorch models
+
+    for p in signature.parameters:
+        if p == "return_loss" and signature.parameters[p].default is True:
+            return True
+
+    return False
+
+
 def find_labels(model_class):
     """
     Find the labels used by a given model.
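As a quick illustration of the new helper: it is purely a signature check, so a plain class is enough to exercise it. The toy classes below are hypothetical, and the import assumes the re-export added to `utils/__init__.py` above:

```python
from transformers.utils import can_return_loss


# Hypothetical toy classes; `can_return_loss` only inspects the signature.
class PlainModel:
    def forward(self, input_ids=None):
        ...


class ContrastiveModel:
    def forward(self, input_ids=None, pixel_values=None, return_loss=True):
        ...


print(can_return_loss(PlainModel))        # False: no `return_loss` parameter
print(can_return_loss(ContrastiveModel))  # True: `return_loss` defaults to True
```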