Unverified Commit f71c9ccf authored by YQ's avatar YQ Committed by GitHub
Browse files

fix logit-to-multi-hot conversion in example (#26936)

* fix logit to multi-hot converstion

* add comments

* typo
parent 093848d3
...@@ -655,7 +655,7 @@ def main(): ...@@ -655,7 +655,7 @@ def main():
preds = np.squeeze(preds) preds = np.squeeze(preds)
result = metric.compute(predictions=preds, references=p.label_ids) result = metric.compute(predictions=preds, references=p.label_ids)
elif is_multi_label: elif is_multi_label:
preds = np.array([np.where(p > 0.5, 1, 0) for p in preds]) preds = np.array([np.where(p > 0, 1, 0) for p in preds]) # convert logits to multi-hot encoding
# Micro F1 is commonly used in multi-label classification # Micro F1 is commonly used in multi-label classification
result = metric.compute(predictions=preds, references=p.label_ids, average="micro") result = metric.compute(predictions=preds, references=p.label_ids, average="micro")
else: else:
...@@ -721,7 +721,10 @@ def main(): ...@@ -721,7 +721,10 @@ def main():
if is_regression: if is_regression:
predictions = np.squeeze(predictions) predictions = np.squeeze(predictions)
elif is_multi_label: elif is_multi_label:
predictions = np.array([np.where(p > 0.5, 1, 0) for p in predictions]) # Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied.
# You can also pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer
# and set p > 0.5 below (less efficient in this case)
predictions = np.array([np.where(p > 0, 1, 0) for p in predictions])
else: else:
predictions = np.argmax(predictions, axis=1) predictions = np.argmax(predictions, axis=1)
output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt") output_predict_file = os.path.join(training_args.output_dir, "predict_results.txt")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment