Unverified commit 0d0d7769 authored by Yih-Dar, committed by GitHub

Allow trainer to return eval. loss for CLIP-like models (#20214)



* Allow trainer to return loss for CLIP-like models

* Apply suggestions

* update

* update

* update
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 822ae69c
src/transformers/trainer.py

@@ -134,6 +134,7 @@ from .utils import (
     CONFIG_NAME,
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
+    can_return_loss,
     find_labels,
     get_full_repo_name,
     is_apex_available,
@@ -625,6 +626,7 @@ class Trainer:
         self.use_tune_checkpoints = False
         default_label_names = find_labels(self.model.__class__)
         self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+        self.can_return_loss = can_return_loss(self.model.__class__)
         self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)

         # Internal variables to keep track of the original batch size
@@ -3190,6 +3192,14 @@ class Trainer:
         logits and labels (each being optional).
         """
         has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        # For CLIP-like models capable of returning loss values.
+        # If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of
+        # `return_loss` is `True` in `model.forward`.
+        return_loss = inputs.get("return_loss", None)
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
         inputs = self._prepare_inputs(inputs)
         if ignore_keys is None:
             if hasattr(self.model, "config"):
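In isolation, the new fallback works like this: an explicit `return_loss` key in `inputs` wins; otherwise the signature default recorded in `self.can_return_loss` decides, and a loss is expected without labels only when `label_names` is empty. A minimal sketch tracing the three cases (the `ClipLike` class and the `wants_loss_without_labels` helper are hypothetical illustrations, not part of the commit; only `can_return_loss` is the real utility added here):

```python
# Sketch of the fallback above, traced outside the Trainer.
from transformers.utils import can_return_loss


class ClipLike:  # hypothetical: forward defaults return_loss=True
    def forward(self, input_ids, pixel_values, return_loss=True):
        ...


def wants_loss_without_labels(model_class, inputs, label_names):
    return_loss = inputs.get("return_loss", None)
    if return_loss is None:  # not supplied by the data collator
        return_loss = can_return_loss(model_class)  # fall back to the signature default
    return len(label_names) == 0 and bool(return_loss)


print(wants_loss_without_labels(ClipLike, {}, []))                      # True
print(wants_loss_without_labels(ClipLike, {"return_loss": False}, []))  # False: explicit input wins
print(wants_loss_without_labels(ClipLike, {}, ["labels"]))              # False: model has label columns
```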
@@ -3198,7 +3208,7 @@ class Trainer:
                 ignore_keys = []

         # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
-        if has_labels:
+        if has_labels or loss_without_labels:
             labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
             if len(labels) == 1:
                 labels = labels[0]
@@ -3208,7 +3218,7 @@ class Trainer:
         with torch.no_grad():
             if is_sagemaker_mp_enabled():
                 raw_outputs = smp_forward_only(model, inputs)
-                if has_labels:
+                if has_labels or loss_without_labels:
                     if isinstance(raw_outputs, dict):
                         loss_mb = raw_outputs["loss"]
                         logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
@@ -3226,7 +3236,7 @@ class Trainer:
                         logits_mb = raw_outputs
                     logits = smp_nested_concat(logits_mb)
             else:
-                if has_labels:
+                if has_labels or loss_without_labels:
                     with self.compute_loss_context_manager():
                         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()
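Taken together, the trainer changes let `Trainer.evaluate` report `eval_loss` for a model that computes its own loss with no label columns at all. A hedged end-to-end sketch (the toy model, dataset, and output directory are invented for illustration; `Trainer`, `TrainingArguments`, and the newly exported `can_return_loss` are the real APIs touched by this commit):

```python
# Sketch: a label-free model whose forward defaults `return_loss=True`,
# evaluated with Trainer after this change. ToyContrastiveModel is a
# hypothetical stand-in for CLIP.
import torch
from torch import nn
from transformers import Trainer, TrainingArguments
from transformers.utils import can_return_loss


class ToyContrastiveModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, x, return_loss=True):
        z = self.proj(x)
        if return_loss:
            # stand-in for a CLIP-style contrastive loss: no labels needed
            return {"loss": (z - x).pow(2).mean(), "logits": z}
        return {"logits": z}


eval_dataset = [{"x": torch.randn(4)} for _ in range(8)]

assert can_return_loss(ToyContrastiveModel)  # detected from the forward signature

trainer = Trainer(
    model=ToyContrastiveModel(),
    args=TrainingArguments(output_dir="/tmp/toy", report_to=[]),
)
metrics = trainer.evaluate(eval_dataset)
print(metrics["eval_loss"])  # present even though label_names is empty
```

Before this commit, `has_labels` was `False` in this scenario, so `prediction_step` skipped `compute_loss` entirely and the metrics dict carried no `eval_loss`.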
src/transformers/utils/__init__.py

@@ -38,6 +38,7 @@ from .generic import (
     PaddingStrategy,
     TensorType,
     cached_property,
+    can_return_loss,
     expand_dims,
     find_labels,
     flatten_dict,
src/transformers/utils/generic.py

@@ -336,6 +336,28 @@ class ContextManagers:
         self.stack.__exit__(*args, **kwargs)


+def can_return_loss(model_class):
+    """
+    Check if a given model can return loss.
+
+    Args:
+        model_class (`type`): The class of the model.
+    """
+    model_name = model_class.__name__
+    if model_name.startswith("TF"):
+        signature = inspect.signature(model_class.call)
+    elif model_name.startswith("Flax"):
+        signature = inspect.signature(model_class.__call__)
+    else:
+        signature = inspect.signature(model_class.forward)
+
+    for p in signature.parameters:
+        if p == "return_loss" and signature.parameters[p].default is True:
+            return True
+
+    return False
+
+
 def find_labels(model_class):
     """
     Find the labels used by a given model.
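For reference, `can_return_loss` keys its signature lookup off the class-name prefix (`call` for TensorFlow, `__call__` for Flax, `forward` otherwise) and only reports `True` when `return_loss` defaults to `True`. A quick sketch with invented toy classes:

```python
# Sketch: all three classes are hypothetical; only can_return_loss is real.
from transformers.utils import can_return_loss


class ToyModel:  # PyTorch convention: `forward` is inspected
    def forward(self, pixel_values, return_loss=True):
        ...


class TFToyModel:  # TensorFlow convention: `call` is inspected
    def call(self, pixel_values, return_loss=True):
        ...


class FlaxToyModel:  # Flax convention: `__call__` is inspected
    def __call__(self, pixel_values, return_loss=None):
        ...


print(can_return_loss(ToyModel))      # True
print(can_return_loss(TFToyModel))    # True
print(can_return_loss(FlaxToyModel))  # False: default is None, not True
```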