"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "0afa5071bd84e44301750fdc594e33db102cf374"
Unverified Commit 900daec2 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing NER pipeline for list inputs. (#10184)

Fixes #10168
parent 587197dc
...@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler): ...@@ -28,11 +28,14 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
Handles arguments for token classification. Handles arguments for token classification.
""" """
def __call__(self, *args, **kwargs): def __call__(self, inputs: Union[str, List[str]], **kwargs):
if args is not None and len(args) > 0: if inputs is not None and isinstance(inputs, (list, tuple)) and len(inputs) > 0:
inputs = list(args) inputs = list(inputs)
batch_size = len(inputs) batch_size = len(inputs)
elif isinstance(inputs, str):
inputs = [inputs]
batch_size = 1
else: else:
raise ValueError("At least one input is required.") raise ValueError("At least one input is required.")
...@@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline): ...@@ -137,11 +140,11 @@ class TokenClassificationPipeline(Pipeline):
Only exists if the offsets are available within the tokenizer Only exists if the offsets are available within the tokenizer
""" """
inputs, offset_mappings = self._args_parser(inputs, **kwargs) _inputs, offset_mappings = self._args_parser(inputs, **kwargs)
answers = [] answers = []
for i, sentence in enumerate(inputs): for i, sentence in enumerate(_inputs):
# Manage correct placement of the tensors # Manage correct placement of the tensors
with self.device_placement(): with self.device_placement():
......
...@@ -14,14 +14,17 @@ ...@@ -14,14 +14,17 @@
import unittest import unittest
from transformers import AutoTokenizer, pipeline from transformers import AutoTokenizer, is_torch_available, pipeline
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
from transformers.testing_utils import require_tf, require_torch, slow from transformers.testing_utils import require_tf, require_torch, slow
from .test_pipelines_common import CustomInputPipelineCommonMixin from .test_pipelines_common import CustomInputPipelineCommonMixin
VALID_INPUTS = ["A simple string", ["list of strings"]] if is_torch_available():
import numpy as np
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
...@@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): ...@@ -334,17 +337,26 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
@require_torch @require_torch
def test_simple(self): def test_simple(self):
nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True) nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
output = nlp("Hello Sarah Jessica Parker who Jessica lives in New York") sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
sentence2 = "This is a simple test"
output = nlp(sentence)
def simplify(output): def simplify(output):
for i in range(len(output)): if isinstance(output, (list, tuple)):
output[i]["score"] = round(output[i]["score"], 3) return [simplify(item) for item in output]
return output elif isinstance(output, dict):
return {simplify(k): simplify(v) for k, v in output.items()}
elif isinstance(output, (str, int, np.int64)):
return output
elif isinstance(output, float):
return round(output, 3)
else:
raise Exception(f"Cannot handle {type(output)}")
output = simplify(output) output_ = simplify(output)
self.assertEqual( self.assertEqual(
output, output_,
[ [
{ {
"entity_group": "PER", "entity_group": "PER",
...@@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase): ...@@ -358,6 +370,21 @@ class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
], ],
) )
output = nlp([sentence, sentence2])
output_ = simplify(output)
self.assertEqual(
output_,
[
[
{"entity_group": "PER", "score": 0.996, "word": "Sarah Jessica Parker", "start": 6, "end": 26},
{"entity_group": "PER", "score": 0.977, "word": "Jessica", "start": 31, "end": 38},
{"entity_group": "LOC", "score": 0.999, "word": "New York", "start": 48, "end": 56},
],
[],
],
)
@require_torch @require_torch
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self): def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
for model_name in self.small_models: for model_name in self.small_models:
...@@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase): ...@@ -386,7 +413,7 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self.assertEqual(inputs, [string]) self.assertEqual(inputs, [string])
self.assertEqual(offset_mapping, None) self.assertEqual(offset_mapping, None)
inputs, offset_mapping = self.args_parser(string, string) inputs, offset_mapping = self.args_parser([string, string])
self.assertEqual(inputs, [string, string]) self.assertEqual(inputs, [string, string])
self.assertEqual(offset_mapping, None) self.assertEqual(offset_mapping, None)
...@@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase): ...@@ -394,25 +421,35 @@ class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
self.assertEqual(inputs, [string]) self.assertEqual(inputs, [string])
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]]) self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]) inputs, offset_mapping = self.args_parser(
[string, string], offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]
)
self.assertEqual(inputs, [string, string]) self.assertEqual(inputs, [string, string])
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]]) self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
def test_errors(self): def test_errors(self):
string = "This is a simple input" string = "This is a simple input"
# 2 sentences, 1 offset_mapping # 2 sentences, 1 offset_mapping, args
with self.assertRaises(ValueError): with self.assertRaises(TypeError):
self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]]) self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])
# 2 sentences, 1 offset_mapping # 2 sentences, 1 offset_mapping, args
with self.assertRaises(ValueError): with self.assertRaises(TypeError):
self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)]) self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])
# 2 sentences, 1 offset_mapping, input_list
with self.assertRaises(ValueError):
self.args_parser([string, string], offset_mapping=[[(0, 1), (1, 2)]])
# 2 sentences, 1 offset_mapping, input_list
with self.assertRaises(ValueError):
self.args_parser([string, string], offset_mapping=[(0, 1), (1, 2)])
# 1 sentences, 2 offset_mapping # 1 sentences, 2 offset_mapping
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]]) self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
# 0 sentences, 1 offset_mapping # 0 sentences, 1 offset_mapping
with self.assertRaises(ValueError): with self.assertRaises(TypeError):
self.args_parser(offset_mapping=[[(0, 1), (1, 2)]]) self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment