Unverified Commit db514a75 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing backward compatiblity for non prefixed tokens (B-, I-). (#13493)

parent e59d4d01
......@@ -411,7 +411,8 @@ class TokenClassificationPipeline(Pipeline):
tag = entity_name[2:]
else:
# It's not in B-, I- format
bi = "B"
# Default to I- for continuation.
bi = "I"
tag = entity_name
return bi, tag
......
......@@ -318,6 +318,59 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
@require_torch
def test_aggregation_strategy_no_b_i_prefix(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
# Just to understand scores indexes in this test
token_classifier.model.config.id2label = {0: "O", 1: "MISC", 2: "PER", 3: "ORG", 4: "LOC"}
example = [
{
# fmt : off
"scores": np.array([0, 0, 0, 0, 0.9968166351318359]),
"index": 1,
"is_subword": False,
"word": "En",
"start": 0,
"end": 2,
},
{
# fmt : off
"scores": np.array([0, 0, 0, 0, 0.9957635998725891]),
"index": 2,
"is_subword": True,
"word": "##zo",
"start": 2,
"end": 4,
},
{
# fmt: off
"scores": np.array([0, 0, 0, 0.9986497163772583, 0]),
# fmt: on
"index": 7,
"word": "UN",
"is_subword": False,
"start": 11,
"end": 13,
},
]
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
[
{"end": 2, "entity": "LOC", "score": 0.997, "start": 0, "word": "En", "index": 1},
{"end": 4, "entity": "LOC", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
{"end": 13, "entity": "ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
],
)
self.assertEqual(
nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
[
{"entity_group": "LOC", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
{"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
],
)
@require_torch
def test_aggregation_strategy(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment