Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f71c9ccf
Unverified
Commit
f71c9ccf
authored
Oct 23, 2023
by
YQ
Committed by
GitHub
Oct 23, 2023
Browse files
fix logit-to-multi-hot conversion in example (#26936)
* fix logit to multi-hot converstion * add comments * typo
parent
093848d3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
2 deletions
+5
-2
examples/pytorch/text-classification/run_classification.py
examples/pytorch/text-classification/run_classification.py
+5
-2
No files found.
examples/pytorch/text-classification/run_classification.py
View file @
f71c9ccf
...
...
@@ -655,7 +655,7 @@ def main():
preds
=
np
.
squeeze
(
preds
)
result
=
metric
.
compute
(
predictions
=
preds
,
references
=
p
.
label_ids
)
elif
is_multi_label
:
preds
=
np
.
array
([
np
.
where
(
p
>
0
.5
,
1
,
0
)
for
p
in
preds
])
preds
=
np
.
array
([
np
.
where
(
p
>
0
,
1
,
0
)
for
p
in
preds
])
# convert logits to multi-hot encoding
# Micro F1 is commonly used in multi-label classification
result
=
metric
.
compute
(
predictions
=
preds
,
references
=
p
.
label_ids
,
average
=
"micro"
)
else
:
...
...
@@ -721,7 +721,10 @@ def main():
if
is_regression
:
predictions
=
np
.
squeeze
(
predictions
)
elif
is_multi_label
:
predictions
=
np
.
array
([
np
.
where
(
p
>
0.5
,
1
,
0
)
for
p
in
predictions
])
# Convert logits to multi-hot encoding. We compare the logits to 0 instead of 0.5, because the sigmoid is not applied.
# You can also pass `preprocess_logits_for_metrics=lambda logits, labels: nn.functional.sigmoid(logits)` to the Trainer
# and set p > 0.5 below (less efficient in this case)
predictions
=
np
.
array
([
np
.
where
(
p
>
0
,
1
,
0
)
for
p
in
predictions
])
else
:
predictions
=
np
.
argmax
(
predictions
,
axis
=
1
)
output_predict_file
=
os
.
path
.
join
(
training_args
.
output_dir
,
"predict_results.txt"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment