Unverified Commit 02b63702 authored by Ivan Sedykh's avatar Ivan Sedykh Committed by GitHub
Browse files

fix seq2seqtrainer predict without labels (#19721)

parent fac1f4b1
......@@ -208,9 +208,9 @@ class Seq2SeqTrainer(Trainer):
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
with torch.no_grad():
if has_labels:
with self.compute_loss_context_manager():
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment