Unverified Commit c0281feb authored by davidleonfdez's avatar davidleonfdez Committed by GitHub
Browse files

Fix #15898 (#15928)

parent 9251427c
...@@ -454,6 +454,10 @@ def main(): ...@@ -454,6 +454,10 @@ def main():
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels): def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1) return logits.argmax(dim=-1)
metric = load_metric("accuracy") metric = load_metric("accuracy")
......
...@@ -477,6 +477,10 @@ def main(): ...@@ -477,6 +477,10 @@ def main():
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
def preprocess_logits_for_metrics(logits, labels): def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1) return logits.argmax(dim=-1)
metric = load_metric("accuracy") metric = load_metric("accuracy")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment