Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cca75e78
Commit
cca75e78
authored
Dec 04, 2019
by
LysandreJik
Browse files
Kill the demon spawn
parent
bf119c05
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
64 deletions
+34
-64
examples/run_squad.py
examples/run_squad.py
+22
-1
transformers/data/processors/squad.py
transformers/data/processors/squad.py
+12
-63
No files found.
examples/run_squad.py
View file @
cca75e78
...
@@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""):
...
@@ -248,7 +248,28 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_feature
=
features
[
example_index
.
item
()]
eval_feature
=
features
[
example_index
.
item
()]
unique_id
=
int
(
eval_feature
.
unique_id
)
unique_id
=
int
(
eval_feature
.
unique_id
)
result
=
SquadResult
([
to_list
(
output
[
i
])
for
output
in
outputs
]
+
[
unique_id
])
output
=
[
to_list
(
output
[
i
])
for
output
in
outputs
]
if
len
(
output
)
>=
5
:
start_logits
=
output
[
0
]
start_top_index
=
output
[
1
]
end_logits
=
output
[
2
]
end_top_index
=
output
[
3
],
cls_logits
=
output
[
4
]
result
=
SquadResult
(
unique_id
,
start_logits
,
end_logits
,
start_top_index
=
start_top_index
,
end_top_index
=
end_top_index
,
cls_logits
=
cls_logits
)
else
:
start_logits
,
end_logits
=
output
result
=
SquadResult
(
unique_id
,
start_logits
,
end_logits
)
all_results
.
append
(
result
)
all_results
.
append
(
result
)
evalTime
=
timeit
.
default_timer
()
-
start_time
evalTime
=
timeit
.
default_timer
()
-
start_time
...
...
transformers/data/processors/squad.py
View file @
cca75e78
...
@@ -446,72 +446,21 @@ class SquadFeatures(object):
...
@@ -446,72 +446,21 @@ class SquadFeatures(object):
self
.
end_position
=
end_position
self
.
end_position
=
end_position
class
SquadResult
(
object
):
class
SquadResult
(
object
):
"""
"""
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset.
Args:
Args:
result: The result output by a model on a SQuAD inference. These results may be complex (5 values) as the ones output by
unique_id: The unique identifier corresponding to that example.
XLNet or XLM or may be simple like the other models (2 values). They may be passed as a list or as a dict, with the
start_logits: The logits corresponding to the start of the answer
following accepted formats:
end_logits: The logits corresponding to the end of the answer
`dict` output by a simple model:
{
"start_logits": int,
"end_logits": int,
"unique_id": string
}
`list` output by a simple model:
[start_logits, end_logits, unique_id]
`dict` output by a complex model:
{
"start_top_log_probs": float,
"start_top_index": int,
"end_top_log_probs": float,
"end_top_index": int,
"cls_logits": int,
"unique_id": string
}
`list` output by a complex model:
[start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, unique_id]
See `run_squad.py` for an example.
"""
"""
def
__init__
(
self
,
result
):
def
__init__
(
self
,
unique_id
,
start_logits
,
end_logits
,
start_top_index
=
None
,
end_top_index
=
None
,
cls_logits
=
None
):
if
isinstance
(
result
,
dict
):
self
.
start_top_log_probs
=
start_logits
if
"start_logits"
in
result
and
"end_logits"
in
result
:
self
.
end_top_log_probs
=
end_logits
self
.
start_logits
=
result
[
"start_logits"
]
self
.
unique_id
=
unique_id
self
.
end_logits
=
result
[
"end_logits"
]
elif
"start_top_log_probs"
in
result
and
"start_top_index"
in
result
:
self
.
start_top_log_probs
=
result
[
"start_top_log_probs"
]
self
.
start_top_index
=
result
[
"start_top_index"
]
self
.
end_top_log_probs
=
result
[
"end_top_log_probs"
]
self
.
end_top_index
=
result
[
"end_top_index"
]
self
.
cls_logits
=
result
[
"cls_logits"
]
else
:
raise
ValueError
(
"SquadResult instantiated with wrong values."
)
self
.
unique_id
=
result
[
"unique_id"
]
elif
isinstance
(
result
,
list
):
if
len
(
result
)
==
3
:
self
.
start_logits
=
result
[
0
]
self
.
end_logits
=
result
[
1
]
elif
len
(
result
)
==
6
:
self
.
start_top_log_probs
=
result
[
0
]
self
.
start_top_index
=
result
[
1
]
self
.
end_top_log_probs
=
result
[
2
]
self
.
end_top_index
=
result
[
3
]
self
.
cls_logits
=
result
[
4
]
else
:
raise
ValueError
(
"SquadResult instantiated with wrong values."
)
self
.
unique_id
=
result
[
-
1
]
else
:
if
start_top_index
:
raise
ValueError
(
"SquadResult instantiated with wrong values. Should be a dictionary or a list."
)
self
.
start_top_index
=
start_top_index
self
.
end_top_index
=
end_top_index
self
.
cls_logits
=
cls_logits
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment