"vscode:/vscode.git/clone" did not exist on "bdf31d6e0a282cb902ad911354c7ba7f042ed945"
Unverified Commit c0281feb authored by davidleonfdez's avatar davidleonfdez Committed by GitHub
Browse files

Fix #15898 (#15928)

parent 9251427c
......@@ -454,6 +454,10 @@ def main():
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1)
metric = load_metric("accuracy")
......
......@@ -477,6 +477,10 @@ def main():
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1)
metric = load_metric("accuracy")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment