Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
900daec2
Unverified
Commit
900daec2
authored
Feb 15, 2021
by
Nicolas Patry
Committed by
GitHub
Feb 15, 2021
Browse files
Fixing NER pipeline for list inputs. (#10184)
Fixes #10168
parent
587197dc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
20 deletions
+60
-20
src/transformers/pipelines/token_classification.py
src/transformers/pipelines/token_classification.py
+8
-5
tests/test_pipelines_ner.py
tests/test_pipelines_ner.py
+52
-15
No files found.
src/transformers/pipelines/token_classification.py
View file @
900daec2
...
...
@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
Handles arguments for token classification.
"""
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
inputs
:
Union
[
str
,
List
[
str
]]
,
**
kwargs
):
if
arg
s
is
not
None
and
len
(
arg
s
)
>
0
:
inputs
=
list
(
arg
s
)
if
input
s
is
not
None
and
isinstance
(
inputs
,
(
list
,
tuple
))
and
len
(
input
s
)
>
0
:
inputs
=
list
(
input
s
)
batch_size
=
len
(
inputs
)
elif
isinstance
(
inputs
,
str
):
inputs
=
[
inputs
]
batch_size
=
1
else
:
raise
ValueError
(
"At least one input is required."
)
...
...
@@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
Only exists if the offsets are available within the tokenizer
"""
inputs
,
offset_mappings
=
self
.
_args_parser
(
inputs
,
**
kwargs
)
_
inputs
,
offset_mappings
=
self
.
_args_parser
(
inputs
,
**
kwargs
)
answers
=
[]
for
i
,
sentence
in
enumerate
(
inputs
):
for
i
,
sentence
in
enumerate
(
_
inputs
):
# Manage correct placement of the tensors
with
self
.
device_placement
():
...
...
tests/test_pipelines_ner.py
View file @
900daec2
...
...
@@ -14,14 +14,17 @@
import
unittest
from
transformers
import
AutoTokenizer
,
pipeline
from
transformers
import
AutoTokenizer
,
is_torch_available
,
pipeline
from
transformers.pipelines
import
Pipeline
,
TokenClassificationArgumentHandler
from
transformers.testing_utils
import
require_tf
,
require_torch
,
slow
from
.test_pipelines_common
import
CustomInputPipelineCommonMixin
VALID_INPUTS
=
[
"A simple string"
,
[
"list of strings"
]]
if
is_torch_available
():
import
numpy
as
np
VALID_INPUTS
=
[
"A simple string"
,
[
"list of strings"
,
"A simple string that is quite a bit longer"
]]
class
NerPipelineTests
(
CustomInputPipelineCommonMixin
,
unittest
.
TestCase
):
...
...
@@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
@
require_torch
def
test_simple
(
self
):
nlp
=
pipeline
(
task
=
"ner"
,
model
=
"dslim/bert-base-NER"
,
grouped_entities
=
True
)
output
=
nlp
(
"Hello Sarah Jessica Parker who Jessica lives in New York"
)
sentence
=
"Hello Sarah Jessica Parker who Jessica lives in New York"
sentence2
=
"This is a simple test"
output
=
nlp
(
sentence
)
def
simplify
(
output
):
for
i
in
range
(
len
(
output
)):
output
[
i
][
"score"
]
=
round
(
output
[
i
][
"score"
],
3
)
return
output
if
isinstance
(
output
,
(
list
,
tuple
)):
return
[
simplify
(
item
)
for
item
in
output
]
elif
isinstance
(
output
,
dict
):
return
{
simplify
(
k
):
simplify
(
v
)
for
k
,
v
in
output
.
items
()}
elif
isinstance
(
output
,
(
str
,
int
,
np
.
int64
)):
return
output
elif
isinstance
(
output
,
float
):
return
round
(
output
,
3
)
else
:
raise
Exception
(
f
"Cannot handle
{
type
(
output
)
}
"
)
output
=
simplify
(
output
)
output
_
=
simplify
(
output
)
self
.
assertEqual
(
output
,
output
_
,
[
{
"entity_group"
:
"PER"
,
...
...
@@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
],
)
output
=
nlp
([
sentence
,
sentence2
])
output_
=
simplify
(
output
)
self
.
assertEqual
(
output_
,
[
[
{
"entity_group"
:
"PER"
,
"score"
:
0.996
,
"word"
:
"Sarah Jessica Parker"
,
"start"
:
6
,
"end"
:
26
},
{
"entity_group"
:
"PER"
,
"score"
:
0.977
,
"word"
:
"Jessica"
,
"start"
:
31
,
"end"
:
38
},
{
"entity_group"
:
"LOC"
,
"score"
:
0.999
,
"word"
:
"New York"
,
"start"
:
48
,
"end"
:
56
},
],
[],
],
)
@
require_torch
def
test_pt_small_ignore_subwords_available_for_fast_tokenizers
(
self
):
for
model_name
in
self
.
small_models
:
...
...
@@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self
.
assertEqual
(
inputs
,
[
string
])
self
.
assertEqual
(
offset_mapping
,
None
)
inputs
,
offset_mapping
=
self
.
args_parser
(
string
,
string
)
inputs
,
offset_mapping
=
self
.
args_parser
(
[
string
,
string
]
)
self
.
assertEqual
(
inputs
,
[
string
,
string
])
self
.
assertEqual
(
offset_mapping
,
None
)
...
...
@@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self
.
assertEqual
(
inputs
,
[
string
])
self
.
assertEqual
(
offset_mapping
,
[[(
0
,
1
),
(
1
,
2
)]])
inputs
,
offset_mapping
=
self
.
args_parser
(
string
,
string
,
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)],
[(
0
,
2
),
(
2
,
3
)]])
inputs
,
offset_mapping
=
self
.
args_parser
(
[
string
,
string
],
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)],
[(
0
,
2
),
(
2
,
3
)]]
)
self
.
assertEqual
(
inputs
,
[
string
,
string
])
self
.
assertEqual
(
offset_mapping
,
[[(
0
,
1
),
(
1
,
2
)],
[(
0
,
2
),
(
2
,
3
)]])
def
test_errors
(
self
):
string
=
"This is a simple input"
# 2 sentences, 1 offset_mapping
with
self
.
assertRaises
(
Valu
eError
):
# 2 sentences, 1 offset_mapping
, args
with
self
.
assertRaises
(
Typ
eError
):
self
.
args_parser
(
string
,
string
,
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)]])
# 2 sentences, 1 offset_mapping
with
self
.
assertRaises
(
Valu
eError
):
# 2 sentences, 1 offset_mapping
, args
with
self
.
assertRaises
(
Typ
eError
):
self
.
args_parser
(
string
,
string
,
offset_mapping
=
[(
0
,
1
),
(
1
,
2
)])
# 2 sentences, 1 offset_mapping, input_list
with
self
.
assertRaises
(
ValueError
):
self
.
args_parser
([
string
,
string
],
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)]])
# 2 sentences, 1 offset_mapping, input_list
with
self
.
assertRaises
(
ValueError
):
self
.
args_parser
([
string
,
string
],
offset_mapping
=
[(
0
,
1
),
(
1
,
2
)])
# 1 sentences, 2 offset_mapping
with
self
.
assertRaises
(
ValueError
):
self
.
args_parser
(
string
,
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)],
[(
0
,
2
),
(
2
,
3
)]])
# 0 sentences, 1 offset_mapping
with
self
.
assertRaises
(
Valu
eError
):
with
self
.
assertRaises
(
Typ
eError
):
self
.
args_parser
(
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment