"test/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "a40229f6f8285e814d9ba8f3469eb99968c85b2e"
Unverified Commit d996024a authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Use compute_loss in prediction_step (#9935)

parent aa438a42
......@@ -1312,7 +1312,7 @@ class Trainer:
return loss.detach()
def compute_loss(self, model, inputs):
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
......@@ -1329,10 +1329,12 @@ class Trainer:
self._past = outputs[self.args.past_index]
if labels is not None:
return self.label_smoother(outputs, labels)
loss = self.label_smoother(outputs, labels)
else:
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
def is_local_process_zero(self) -> bool:
"""
......@@ -1718,29 +1720,27 @@ class Trainer:
ignore_keys = []
with torch.no_grad():
if self.use_amp:
with autocast():
outputs = model(**inputs)
else:
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None and "labels" in inputs:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
logits = outputs[1:]
else:
loss = None
if self.use_amp:
with autocast():
outputs = model(**inputs)
else:
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
self._past = outputs[self.args.past_index - 1]
if prediction_loss_only:
return (loss, None, None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment