Unverified Commit d996024a, authored by Sylvain Gugger and committed by GitHub

Use compute_loss in prediction_step (#9935)

parent aa438a42
@@ -1312,7 +1312,7 @@ class Trainer:
         return loss.detach()
 
-    def compute_loss(self, model, inputs):
+    def compute_loss(self, model, inputs, return_outputs=False):
         """
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
@@ -1329,10 +1329,12 @@ class Trainer:
             self._past = outputs[self.args.past_index]
 
         if labels is not None:
-            return self.label_smoother(outputs, labels)
+            loss = self.label_smoother(outputs, labels)
         else:
             # We don't use .loss here since the model may return tuples instead of ModelOutput.
-            return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+        return (loss, outputs) if return_outputs else loss
 
     def is_local_process_zero(self) -> bool:
         """
@@ -1718,29 +1720,27 @@ class Trainer:
             ignore_keys = []
 
         with torch.no_grad():
-            if self.use_amp:
-                with autocast():
-                    outputs = model(**inputs)
-            else:
-                outputs = model(**inputs)
             if has_labels:
-                if self.label_smoother is not None and "labels" in inputs:
-                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
-                else:
-                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
+                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                loss = loss.mean().detach()
                 if isinstance(outputs, dict):
                     logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                 else:
                     logits = outputs[1:]
             else:
                 loss = None
+                if self.use_amp:
+                    with autocast():
+                        outputs = model(**inputs)
+                else:
+                    outputs = model(**inputs)
                 if isinstance(outputs, dict):
                     logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                 else:
                     logits = outputs
-            # TODO: this needs to be fixed and made cleaner later.
-            if self.args.past_index >= 0:
-                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
+                # TODO: this needs to be fixed and made cleaner later.
+                if self.args.past_index >= 0:
+                    self._past = outputs[self.args.past_index - 1]
 
         if prediction_loss_only:
             return (loss, None, None)
...
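With this change, evaluation goes through the same loss computation as training (label smoothing included) instead of re-deriving the loss from raw outputs inside `prediction_step`. A condensed sketch of the resulting control flow; `use_amp` as a plain argument and the single `logits` key are simplifications, not the exact `Trainer` attributes:

```python
import torch
from torch.cuda.amp import autocast


def prediction_step_sketch(trainer, model, inputs, use_amp=False):
    """Condensed view of the patched logic: compute_loss drives the labeled path."""
    has_labels = "labels" in inputs
    with torch.no_grad():
        if has_labels:
            # Single source of truth for the loss (label smoothing lives in compute_loss).
            loss, outputs = trainer.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs[1:]
        else:
            loss = None
            # No labels means no loss, so the forward pass happens directly here.
            if use_amp:
                with autocast():
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs
    return loss, logits
```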