Unverified Commit 0f226f78 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

push (#10846)

parent 82b8d8c7
...@@ -401,7 +401,7 @@ def evaluate(batch): ...@@ -401,7 +401,7 @@ def evaluate(batch):
with torch.no_grad(): with torch.no_grad():
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1) pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids) batch["pred_strings"] = processor.batch_decode(pred_ids)
return batch return batch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment