test_pipelines_ner.py 4.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import unittest

from transformers import pipeline
from transformers.pipelines import Pipeline
from transformers.testing_utils import require_tf

from .test_pipelines_common import CustomInputPipelineCommonMixin


VALID_INPUTS = ["A simple string", ["list of strings"]]


class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
    pipeline_task = "ner"
    small_models = [
        "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
    ]  # Default model - Models tested without the @slow decorator
    large_models = []  # Models tested with the @slow decorator

    def _test_pipeline(self, nlp: Pipeline):
        output_keys = {"entity", "word", "score"}

        ungrouped_ner_inputs = [
            [
                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "word": "Cons"},
                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "word": "##uelo"},
                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "word": "Ara"},
                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "word": "##煤j"},
                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "word": "##o"},
                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "word": "No"},
                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "word": "##guera"},
                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "word": "Andr茅s"},
                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "word": "Pas"},
                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "word": "##tran"},
                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "word": "##a"},
                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "word": "Far"},
                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "word": "##c"},
            ],
            [
                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "word": "En"},
                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "word": "##zo"},
                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "word": "UN"},
            ],
        ]
        expected_grouped_ner_results = [
            [
                {"entity_group": "B-PER", "score": 0.9710702640669686, "word": "Consuelo Ara煤jo Noguera"},
                {"entity_group": "B-PER", "score": 0.9997273534536362, "word": "Andr茅s Pastrana"},
                {"entity_group": "B-ORG", "score": 0.8589080572128296, "word": "Farc"},
            ],
            [
                {"entity_group": "I-PER", "score": 0.9962901175022125, "word": "Enzo"},
                {"entity_group": "I-ORG", "score": 0.9986497163772583, "word": "UN"},
            ],
        ]

        self.assertIsNotNone(nlp)

        mono_result = nlp(VALID_INPUTS[0])
        self.assertIsInstance(mono_result, list)
        self.assertIsInstance(mono_result[0], (dict, list))

        if isinstance(mono_result[0], list):
            mono_result = mono_result[0]

        for key in output_keys:
            self.assertIn(key, mono_result[0])

        multi_result = [nlp(input) for input in VALID_INPUTS]
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], (dict, list))

        if isinstance(multi_result[0], list):
            multi_result = multi_result[0]

        for result in multi_result:
            for key in output_keys:
                self.assertIn(key, result)

        for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
            self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)

    @require_tf
    def test_tf_only(self):
        model_name = "Narsil/small"  # This model only has a TensorFlow version
        # We test that if we don't specificy framework='tf', it gets detected automatically
        nlp = pipeline(task="ner", model=model_name, tokenizer=model_name)
        self._test_pipeline(nlp)