Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bcd8dc6b
Commit
bcd8dc6b
authored
Nov 05, 2019
by
VictorSanh
Committed by
Lysandre Debut
Nov 27, 2019
Browse files
move xnli_compute_metrics to data/metrics
parent
73fe2e73
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
9 deletions
+10
-9
transformers/__init__.py
transformers/__init__.py
+1
-1
transformers/data/__init__.py
transformers/data/__init__.py
+1
-1
transformers/data/metrics/__init__.py
transformers/data/metrics/__init__.py
+8
-0
transformers/data/processors/xnli.py
transformers/data/processors/xnli.py
+0
-7
No files found.
transformers/__init__.py
View file @
bcd8dc6b
...
...
@@ -29,7 +29,7 @@ from .data import (is_sklearn_available,
xnli_output_modes
,
xnli_processors
,
xnli_tasks_num_labels
)
if
is_sklearn_available
():
from
.data
import
glue_compute_metrics
from
.data
import
glue_compute_metrics
,
xnli_compute_metrics
# Tokenizers
from
.tokenization_utils
import
(
PreTrainedTokenizer
)
...
...
transformers/data/__init__.py
View file @
bcd8dc6b
...
...
@@ -4,4 +4,4 @@ from .processors import xnli_output_modes, xnli_processors, xnli_tasks_num_label
from
.metrics
import
is_sklearn_available
if
is_sklearn_available
():
from
.metrics
import
glue_compute_metrics
from
.metrics
import
glue_compute_metrics
,
xnli_compute_metrics
transformers/data/metrics/__init__.py
View file @
bcd8dc6b
...
...
@@ -81,3 +81,11 @@ if _has_sklearn:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
else
:
raise
KeyError
(
task_name
)
def
xnli_compute_metrics
(
task_name
,
preds
,
labels
):
assert
len
(
preds
)
==
len
(
labels
)
if
task_name
==
"xnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
else
:
raise
KeyError
(
task_name
)
transformers/data/processors/xnli.py
View file @
bcd8dc6b
...
...
@@ -73,13 +73,6 @@ class XnliProcessor(DataProcessor):
"""See base class."""
return
[
"contradiction"
,
"entailment"
,
"neutral"
]
def
xnli_compute_metrics
(
task_name
,
preds
,
labels
):
assert
len
(
preds
)
==
len
(
labels
)
if
task_name
==
"xnli"
:
return
{
"acc"
:
simple_accuracy
(
preds
,
labels
)}
else
:
raise
ValueError
(
'{} is not a supported task.'
.
format
(
task_name
))
xnli_processors
=
{
"xnli"
:
XnliProcessor
,
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment