Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
900daec2
"src/vscode:/vscode.git/clone" did not exist on "5d49b3e83b97b45d8745ed6fc9f06c32d5ef9286"
Unverified
Commit
900daec2
authored
Feb 15, 2021
by
Nicolas Patry
Committed by
GitHub
Feb 15, 2021
Browse files
Fixing NER pipeline for list inputs. (#10184)
Fixes #10168
parent
587197dc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
20 deletions
+60
-20
src/transformers/pipelines/token_classification.py
src/transformers/pipelines/token_classification.py
+8
-5
tests/test_pipelines_ner.py
tests/test_pipelines_ner.py
+52
-15
No files found.
src/transformers/pipelines/token_classification.py
View file @
900daec2
...
...
@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
Handles arguments for token classification.
"""
def __call__(self, inputs: Union[str, List[str]], **kwargs):
    """Normalize pipeline inputs to a list of sentences.

    Accepts either a single string or a non-empty list/tuple of strings
    (the explicit ``inputs`` parameter replaces the old ``*args`` form so
    list inputs are handled as a batch instead of positional sentences).

    Args:
        inputs: One sentence, or a list/tuple of sentences.
        **kwargs: Extra arguments (e.g. ``offset_mapping``) handled by the
            remainder of the method (not shown in this hunk).

    Raises:
        ValueError: If ``inputs`` is ``None``, an empty sequence, or not a
            string/sequence of strings.
    """
    if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
        # Copy into a plain list so tuples and other sequences behave uniformly.
        inputs = list(inputs)
        batch_size = len(inputs)
    elif isinstance(inputs, str):
        # Single sentence: wrap it so downstream code always sees a batch.
        inputs = [inputs]
        batch_size = 1
    else:
        raise ValueError("At least one input is required.")
...
...
@@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
Only exists if the offsets are available within the tokenizer
"""
inputs
,
offset_mappings
=
self
.
_args_parser
(
inputs
,
**
kwargs
)
_
inputs
,
offset_mappings
=
self
.
_args_parser
(
inputs
,
**
kwargs
)
answers
=
[]
for
i
,
sentence
in
enumerate
(
inputs
):
for
i
,
sentence
in
enumerate
(
_
inputs
):
# Manage correct placement of the tensors
with
self
.
device_placement
():
...
...
tests/test_pipelines_ner.py
View file @
900daec2
...
...
@@ -14,14 +14,17 @@
import
unittest
from
transformers
import
AutoTokenizer
,
pipeline
from
transformers
import
AutoTokenizer
,
is_torch_available
,
pipeline
from
transformers.pipelines
import
Pipeline
,
TokenClassificationArgumentHandler
from
transformers.testing_utils
import
require_tf
,
require_torch
,
slow
from
.test_pipelines_common
import
CustomInputPipelineCommonMixin
# numpy is only needed (and only guaranteed importable) when torch is installed;
# it is used by the test helpers to recognize np.int64 scores/indices.
if is_torch_available():
    import numpy as np

# Includes a multi-sentence list entry so the common-pipeline mixin exercises
# batched (list) inputs, which is what the NER-pipeline fix addresses.
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
class
NerPipelineTests
(
CustomInputPipelineCommonMixin
,
unittest
.
TestCase
):
...
...
@@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
@
require_torch
def
test_simple
(
self
):
nlp
=
pipeline
(
task
=
"ner"
,
model
=
"dslim/bert-base-NER"
,
grouped_entities
=
True
)
output
=
nlp
(
"Hello Sarah Jessica Parker who Jessica lives in New York"
)
sentence
=
"Hello Sarah Jessica Parker who Jessica lives in New York"
sentence2
=
"This is a simple test"
output
=
nlp
(
sentence
)
def simplify(output):
    """Recursively round every float in a pipeline output to 3 decimals.

    Makes nested NER outputs (lists of dicts of scores/strings/indices)
    comparable with ``assertEqual`` despite floating-point noise.

    Args:
        output: A (possibly nested) structure of lists/tuples, dicts,
            strings, ints (incl. ``np.int64``), and floats.

    Returns:
        The same structure with floats rounded to 3 decimal places and
        tuples converted to lists.

    Raises:
        Exception: If an unsupported type is encountered.
    """
    if isinstance(output, (list, tuple)):
        return [simplify(item) for item in output]
    elif isinstance(output, dict):
        return {simplify(k): simplify(v) for k, v in output.items()}
    elif isinstance(output, (str, int, np.int64)):
        # np.int64 covers token indices coming back from torch/numpy.
        return output
    elif isinstance(output, float):
        return round(output, 3)
    else:
        raise Exception(f"Cannot handle {type(output)}")
output
=
simplify
(
output
)
output
_
=
simplify
(
output
)
self
.
assertEqual
(
output
,
output
_
,
[
{
"entity_group"
:
"PER"
,
...
...
@@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
],
)
output
=
nlp
([
sentence
,
sentence2
])
output_
=
simplify
(
output
)
self
.
assertEqual
(
output_
,
[
[
{
"entity_group"
:
"PER"
,
"score"
:
0.996
,
"word"
:
"Sarah Jessica Parker"
,
"start"
:
6
,
"end"
:
26
},
{
"entity_group"
:
"PER"
,
"score"
:
0.977
,
"word"
:
"Jessica"
,
"start"
:
31
,
"end"
:
38
},
{
"entity_group"
:
"LOC"
,
"score"
:
0.999
,
"word"
:
"New York"
,
"start"
:
48
,
"end"
:
56
},
],
[],
],
)
@
require_torch
def
test_pt_small_ignore_subwords_available_for_fast_tokenizers
(
self
):
for
model_name
in
self
.
small_models
:
...
...
@@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self
.
assertEqual
(
inputs
,
[
string
])
self
.
assertEqual
(
offset_mapping
,
None
)
inputs
,
offset_mapping
=
self
.
args_parser
(
string
,
string
)
inputs
,
offset_mapping
=
self
.
args_parser
(
[
string
,
string
]
)
self
.
assertEqual
(
inputs
,
[
string
,
string
])
self
.
assertEqual
(
offset_mapping
,
None
)
...
...
@@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self
.
assertEqual
(
inputs
,
[
string
])
self
.
assertEqual
(
offset_mapping
,
[[(
0
,
1
),
(
1
,
2
)]])
inputs
,
offset_mapping
=
self
.
args_parser
(
string
,
string
,
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)],
[(
0
,
2
),
(
2
,
3
)]])
inputs
,
offset_mapping
=
self
.
args_parser
(
[
string
,
string
],
offset_mapping
=
[[(
0
,
1
),
(
1
,
2
)],
[(
0
,
2
),
(
2
,
3
)]]
)
self
.
assertEqual
(
inputs
,
[
string
,
string
])
self
.
assertEqual
(
offset_mapping
,
[[(
0
,
1
),
(
1
,
2
)],
[(
0
,
2
),
(
2
,
3
)]])
def test_errors(self):
    """Mismatched sentence/offset_mapping combinations must raise.

    Passing sentences as extra positional args is no longer supported by
    the explicit ``inputs`` signature, so those calls raise ``TypeError``;
    genuine count mismatches between sentences and offset mappings raise
    ``ValueError``.
    """
    string = "This is a simple input"

    # 2 sentences passed as positional args: signature rejects the extra arg.
    with self.assertRaises(TypeError):
        self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])

    # 2 sentences as positional args with a flat offset_mapping: same rejection.
    with self.assertRaises(TypeError):
        self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])

    # 2 sentences (as a list), 1 offset_mapping: count mismatch.
    with self.assertRaises(ValueError):
        self.args_parser([string, string], offset_mapping=[[(0, 1), (1, 2)]])

    # 2 sentences (as a list), 1 flat offset_mapping: count mismatch.
    with self.assertRaises(ValueError):
        self.args_parser([string, string], offset_mapping=[(0, 1), (1, 2)])

    # 1 sentence, 2 offset_mappings: count mismatch.
    with self.assertRaises(ValueError):
        self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])

    # 0 sentences: ``inputs`` is now a required parameter.
    with self.assertRaises(TypeError):
        self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment