"vscode:/vscode.git/clone" did not exist on "d996024af7a94b0f11a5ad351217b648ecaed72a"
Unverified Commit a3bd7637 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Better heuristic for token-classification pipeline. (#12611)

* Better heuristic for token-classification pipeline.

Looking at the problem again actually makes things much simpler:
when we look at ids from a tokenizer, we have no way in **general**
to tell whether some substring is part of a word or not.

However, within the pipeline, the offsets still give us access to the
original string, so we can simply check whether the character preceding
a token (if it exists) is actually a space. This will obviously be wrong
for tokenizers that contain spaces within tokens, or tokenizers whose
offsets include spaces too (I don't think there are many of those).

Hopefully this heuristic is fully backward compatible (BC) and can still
handle non-word-based tokenizers.

* Updating test with real values.

* We still need the older "correct" heuristic to prevent fusing
punctuation.

* Adding a real warning when important.
parent 569f61a7
......@@ -270,7 +270,19 @@ class TokenClassificationPipeline(Pipeline):
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
is_subword = len(word_ref) != len(word)
if getattr(self.tokenizer._tokenizer.model, "continuing_subword_prefix", None):
# This is a BPE, word aware tokenizer, there is a correct way
# to fuse tokens
is_subword = len(word) != len(word_ref)
else:
# This is a fallback heuristic. This will fail most likely on any kind of text + punctuation mixtures that will be considered "words". Non word aware models cannot do better than this unfortunately.
# Warn only for FIRST, AVERAGE and MAX, the strategies checked here, that the
# fallback heuristic is in use.
# warnings.warn takes (message, category): the message text must come first and
# the Warning subclass second. With the arguments swapped the call raises
# TypeError ("category must be a Warning subclass") instead of warning.
if self.aggregation_strategy in {
    AggregationStrategy.FIRST,
    AggregationStrategy.AVERAGE,
    AggregationStrategy.MAX,
}:
    warnings.warn("Tokenizer does not support real words, using fallback heuristic", UserWarning)
is_subword = sentence[start_ind - 1 : start_ind] != " " if start_ind > 0 else False
if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
......
......@@ -191,6 +191,19 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.
],
)
@require_torch
@slow
def test_aggregation_strategy_byte_level_tokenizer(self):
    # Run the "ner" pipeline with aggregation_strategy="max" on a short Dutch
    # sentence and compare the simplified output with two expected groups.
    token_classifier = pipeline(
        "ner",
        model="xlm-roberta-large-finetuned-conll02-dutch",
        aggregation_strategy="max",
    )
    predictions = token_classifier("Groenlinks praat over Schiphol.")

    # Note that the second expected group keeps the trailing period
    # ("Schiphol.", characters 22-31 of the sentence).
    expected_groups = [
        {"end": 10, "entity_group": "ORG", "score": 0.994, "start": 0, "word": "Groenlinks"},
        {"entity_group": "LOC", "score": 1.0, "word": "Schiphol.", "start": 22, "end": 31},
    ]
    self.assertEqual(nested_simplify(predictions), expected_groups)
@require_torch
def test_aggregation_strategy(self):
model_name = self.small_models[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment