Unverified commit b88e0e01, authored by Nicolas Patry and committed by GitHub

[TokenClassification] Label realignment for subword aggregation (#11680)

* [TokenClassification] Label realignment for subword aggregation

An attempt to replace https://github.com/huggingface/transformers/pull/11622/files



- Added `AggregationStrategy`.
- The `ignore_subwords` and `grouped_entities` arguments are now merged
  into `aggregation_strategy`. This makes more sense, since
  `ignore_subwords=True` with `grouped_entities=False` had no meaning.
- Added two new aggregation strategies: MAX and AVERAGE (a usage sketch
  follows this list).
- AVERAGE requires a bit more information than the others; for now this
  case is slightly specific, which we should keep in mind for future
  changes.
- Tests have been updated to reflect the new argument, and to check both
  the deprecation path and the new aggregation_strategy.
- Put the test inputs and expected results for aggregation_strategy
  close together, so that readers can understand what is supposed to
  happen.
- `aggregate` is now only tested on a small model, as it is not
  meaningful to test it globally across all models.
- Previous tests are unchanged in desired output.
- Added a new test case that better showcases the difference between the
  FIRST, MAX and AVERAGE strategies.
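
As a quick orientation, the new argument is used roughly like this (a sketch only: the checkpoint is the one exercised in the slow tests, and the printed output is illustrative):

    from transformers import pipeline

    # Strings are upper-cased into the AggregationStrategy enum, so
    # aggregation_strategy="first" and AggregationStrategy.FIRST are equivalent.
    token_classifier = pipeline(
        "ner",
        model="dslim/bert-base-NER",  # any NER checkpoint with a fast tokenizer
        aggregation_strategy="first",
    )

    # Subword tokens are realigned onto word boundaries before grouping, so a
    # word split into several subtokens comes back as a single entity.
    print(token_classifier("Sarah Jessica Parker lives in New York"))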

* Wrong framework.

* Addressing three issues.

1- Tags might not follow the B-, I- convention, so any tag now works
(a bare tag is assumed to be B-TAG; sketched below).
2- Fixed an issue with AVERAGE that led to a substantial code change.
3- The testing suite was not checking the "index" key for the "none"
strategy. This is now fixed.

The issue is that "O" could never be chosen by the AVERAGE strategy,
because those tokens were filtered out beforehand, so their relative
scores were not counted in the average. Filtering on
ignore_labels now happens at the very end of the pipeline, fixing
that issue.
It's a bit hard to make sure this stays that way, because we do
not have an end-to-end test for that behavior.
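
For reference, the tag parsing that (1) implies can be sketched as follows (this mirrors the get_tag helper added in the diff below; a bare tag with no B-/I- prefix is treated as opening a new entity):

    def get_tag(entity_name: str):
        # "B-PER" -> ("B", "PER"); "I-PER" -> ("I", "PER")
        if entity_name.startswith("B-"):
            return "B", entity_name[2:]
        if entity_name.startswith("I-"):
            return "I", entity_name[2:]
        # Not in B-/I- format: assume it starts a new entity of that tag.
        return "B", entity_name

    assert get_tag("PER") == ("B", "PER")
    assert get_tag("I-PER") == ("I", "PER")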

* Formatting.

* Adding formatting to code + cleaner handling of B-, I- tags.
Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com>
Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com>

* Typo.
Co-authored-by: Francesco Rubbo <rubbo.francesco@gmail.com>
Co-authored-by: elk-cloner <rezakakhki.rk@gmail.com>
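
For downstream users, the compatibility shim in the pipeline constructor maps the deprecated flags onto the new strategies roughly as follows (a summary of the logic in the diff below; each deprecated flag that is actually passed also emits a UserWarning):

    # grouped_entities=True,  ignore_subwords=True        -> aggregation_strategy="first"
    # grouped_entities=True,  ignore_subwords=False/None  -> aggregation_strategy="simple"
    # any other combination of the deprecated flags       -> aggregation_strategy="none"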
parent c73e3532
@@ -48,7 +48,12 @@
 from .text2text_generation import SummarizationPipeline, Text2TextGenerationPipeline, TranslationPipeline
 from .text_classification import TextClassificationPipeline
 from .text_generation import TextGenerationPipeline
-from .token_classification import NerPipeline, TokenClassificationArgumentHandler, TokenClassificationPipeline
+from .token_classification import (
+    AggregationStrategy,
+    NerPipeline,
+    TokenClassificationArgumentHandler,
+    TokenClassificationPipeline,
+)
 from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
-from typing import TYPE_CHECKING, List, Optional, Union
+import warnings
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import numpy as np

-from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
+from ..file_utils import ExplicitEnum, add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
 from ..models.bert.tokenization_bert import BasicTokenizer
 from ..tokenization_utils import PreTrainedTokenizer
@@ -48,13 +49,43 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
         return inputs, offset_mapping


+class AggregationStrategy(ExplicitEnum):
+    """All the valid aggregation strategies for TokenClassificationPipeline"""
+
+    NONE = "none"
+    SIMPLE = "simple"
+    FIRST = "first"
+    AVERAGE = "average"
+    MAX = "max"
+
+
 @add_end_docstrings(
     PIPELINE_INIT_ARGS,
     r"""
         ignore_labels (:obj:`List[str]`, defaults to :obj:`["O"]`):
             A list of labels to ignore.
         grouped_entities (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to group the tokens corresponding to the same entity together in the predictions or not.
+            DEPRECATED, use :obj:`aggregation_strategy` instead. Whether or not to group the tokens corresponding to
+            the same entity together in the predictions or not.
+        aggregation_strategy (:obj:`str`, `optional`, defaults to :obj:`"none"`):
+            The strategy to fuse (or not) tokens based on the model prediction.
+
+                - "none" : Will simply not do any aggregation and simply return raw results from the model.
+                - "simple" : Will attempt to group entities following the default schema: (A, B-TAG), (B, I-TAG), (C,
+                  I-TAG), (D, B-TAG2), (E, B-TAG2) will end up being [{"word": "ABC", "entity": "TAG"}, {"word": "D",
+                  "entity": "TAG2"}, {"word": "E", "entity": "TAG2"}]. Notice that two consecutive B tags will end up
+                  as different entities. On word based languages, we might end up splitting words undesirably: imagine
+                  Microsoft being tagged as [{"word": "Micro", "entity": "ENTERPRISE"}, {"word": "soft", "entity":
+                  "NAME"}]. Look at FIRST, MAX and AVERAGE for ways to mitigate this and disambiguate words (on
+                  languages that support that meaning, which is basically tokens separated by a space). These
+                  mitigations will only work on real words: "New york" might still be tagged with two different
+                  entities.
+                - "first" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words
+                  cannot end up with different tags. Words will simply use the tag of the first token of the word when
+                  there is ambiguity.
+                - "average" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words
+                  cannot end up with different tags. Scores will be averaged first across tokens, and then the maximum
+                  label is applied.
+                - "max" : (works only on word based models) Will use the :obj:`SIMPLE` strategy except that words
+                  cannot end up with different tags. The word entity will simply be the token with the maximum score.
     """,
 )
 class TokenClassificationPipeline(Pipeline):
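
To make the docstring's word-splitting example concrete, a hypothetical comparison (the entity labels and the split of "Microsoft" are invented for illustration):

    # aggregation_strategy="simple" can split one word across entities:
    #   [{"entity_group": "ENTERPRISE", "word": "Micro", ...},
    #    {"entity_group": "NAME", "word": "##soft", ...}]
    #
    # aggregation_strategy="first" forces a single tag per word, taken from
    # the word's first token:
    #   [{"entity_group": "ENTERPRISE", "word": "Microsoft", ...}]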
@@ -84,8 +115,9 @@ class TokenClassificationPipeline(Pipeline):
         binary_output: bool = False,
         ignore_labels=["O"],
         task: str = "",
-        grouped_entities: bool = False,
-        ignore_subwords: bool = False,
+        grouped_entities: Optional[bool] = None,
+        ignore_subwords: Optional[bool] = None,
+        aggregation_strategy: Optional[AggregationStrategy] = None,
     ):
         super().__init__(
             model=model,
@@ -106,15 +138,40 @@ class TokenClassificationPipeline(Pipeline):
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self._args_parser = args_parser
         self.ignore_labels = ignore_labels
-        self.grouped_entities = grouped_entities
-        self.ignore_subwords = ignore_subwords
-
-        if self.ignore_subwords and not self.tokenizer.is_fast:
+        if aggregation_strategy is None:
+            aggregation_strategy = AggregationStrategy.NONE
+            if grouped_entities is not None or ignore_subwords is not None:
+                if grouped_entities and ignore_subwords:
+                    aggregation_strategy = AggregationStrategy.FIRST
+                elif grouped_entities and not ignore_subwords:
+                    aggregation_strategy = AggregationStrategy.SIMPLE
+                else:
+                    aggregation_strategy = AggregationStrategy.NONE
+
+                if grouped_entities is not None:
+                    warnings.warn(
+                        f'`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+                    )
+                if ignore_subwords is not None:
+                    warnings.warn(
+                        f'`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="{aggregation_strategy}"` instead.'
+                    )
+        if isinstance(aggregation_strategy, str):
+            aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
+
+        if (
+            aggregation_strategy in {AggregationStrategy.FIRST, AggregationStrategy.MAX, AggregationStrategy.AVERAGE}
+            and not self.tokenizer.is_fast
+        ):
             raise ValueError(
-                "Slow tokenizers cannot ignore subwords. Please set the `ignore_subwords` option"
-                "to `False` or use a fast tokenizer."
+                "Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option "
+                'to `"simple"` or use a fast tokenizer.'
             )
+        self.aggregation_strategy = aggregation_strategy

     def __call__(self, inputs: Union[str, List[str]], **kwargs):
         """
         Classify each token of the text(s) given as inputs.
@@ -125,14 +182,14 @@ class TokenClassificationPipeline(Pipeline):
         Return:
             A list or a list of list of :obj:`dict`: Each result comes as a list of dictionaries (one for each token in
-            the corresponding input, or each entity if this pipeline was instantiated with
-            :obj:`grouped_entities=True`) with the following keys:
+            the corresponding input, or each entity if this pipeline was instantiated with an aggregation_strategy)
+            with the following keys:

             - **word** (:obj:`str`) -- The token/word classified.
             - **score** (:obj:`float`) -- The corresponding probability for :obj:`entity`.
             - **entity** (:obj:`str`) -- The entity predicted for that token/word (it is named `entity_group` when
-              `grouped_entities` is set to True.
-            - **index** (:obj:`int`, only present when ``self.grouped_entities=False``) -- The index of the
+              `aggregation_strategy` is not :obj:`"none"`).
+            - **index** (:obj:`int`, only present when ``aggregation_strategy="none"``) -- The index of the
               corresponding token in the sentence.
             - **start** (:obj:`int`, `optional`) -- The index of the start of the corresponding entity in the sentence.
               Only exists if the offsets are available within the tokenizer
@@ -176,57 +233,141 @@ class TokenClassificationPipeline(Pipeline):
                         entities = self.model(**tokens)[0][0].cpu().numpy()
                         input_ids = tokens["input_ids"].cpu().numpy()[0]

-            score = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
-            labels_idx = score.argmax(axis=-1)
-
-            entities = []
-            # Filter to labels not in `self.ignore_labels`
-            # Filter special_tokens
-            filtered_labels_idx = [
-                (idx, label_idx)
-                for idx, label_idx in enumerate(labels_idx)
-                if (self.model.config.id2label[label_idx] not in self.ignore_labels) and not special_tokens_mask[idx]
-            ]
-
-            for idx, label_idx in filtered_labels_idx:
-                if offset_mapping is not None:
-                    start_ind, end_ind = offset_mapping[idx]
-                    word_ref = sentence[start_ind:end_ind]
-                    word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
-                    is_subword = len(word_ref) != len(word)
-
-                    if int(input_ids[idx]) == self.tokenizer.unk_token_id:
-                        word = word_ref
-                        is_subword = False
-                else:
-                    word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
-                    start_ind = None
-                    end_ind = None
-
-                entity = {
-                    "word": word,
-                    "score": score[idx][label_idx].item(),
-                    "entity": self.model.config.id2label[label_idx],
-                    "index": idx,
-                    "start": start_ind,
-                    "end": end_ind,
-                }
-
-                if self.grouped_entities and self.ignore_subwords:
-                    entity["is_subword"] = is_subword
-
-                entities += [entity]
-
-            if self.grouped_entities:
-                answers += [self.group_entities(entities)]
-            # Append ungrouped entities
-            else:
-                answers += [entities]
+            scores = np.exp(entities) / np.exp(entities).sum(-1, keepdims=True)
+            pre_entities = self.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
+            grouped_entities = self.aggregate(pre_entities, self.aggregation_strategy)
+            # Filter anything that is in self.ignore_labels
+            entities = [
+                entity
+                for entity in grouped_entities
+                if entity.get("entity", None) not in self.ignore_labels
+                and entity.get("entity_group", None) not in self.ignore_labels
+            ]
+            answers.append(entities)

         if len(answers) == 1:
             return answers[0]
         return answers

+    def gather_pre_entities(
+        self,
+        sentence: str,
+        input_ids: np.ndarray,
+        scores: np.ndarray,
+        offset_mapping: Optional[List[Tuple[int, int]]],
+        special_tokens_mask: np.ndarray,
+    ) -> List[dict]:
+        """Fuse various numpy arrays into dicts with all the information needed for aggregation"""
+        pre_entities = []
+        for idx, token_scores in enumerate(scores):
+            # Filter special_tokens, they should only occur
+            # at the sentence boundaries since we're not encoding pairs of
+            # sentences so we don't have to keep track of those.
+            if special_tokens_mask[idx]:
+                continue
+
+            word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
+            if offset_mapping is not None:
+                start_ind, end_ind = offset_mapping[idx]
+                word_ref = sentence[start_ind:end_ind]
+                is_subword = len(word_ref) != len(word)
+
+                if int(input_ids[idx]) == self.tokenizer.unk_token_id:
+                    word = word_ref
+                    is_subword = False
+            else:
+                start_ind = None
+                end_ind = None
+                is_subword = False
+
+            pre_entity = {
+                "word": word,
+                "scores": token_scores,
+                "start": start_ind,
+                "end": end_ind,
+                "index": idx,
+                "is_subword": is_subword,
+            }
+            pre_entities.append(pre_entity)
+        return pre_entities
+
+    def aggregate(self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
+        if aggregation_strategy in {AggregationStrategy.NONE, AggregationStrategy.SIMPLE}:
+            entities = []
+            for pre_entity in pre_entities:
+                entity_idx = pre_entity["scores"].argmax()
+                score = pre_entity["scores"][entity_idx]
+                entity = {
+                    "entity": self.model.config.id2label[entity_idx],
+                    "score": score,
+                    "index": pre_entity["index"],
+                    "word": pre_entity["word"],
+                    "start": pre_entity["start"],
+                    "end": pre_entity["end"],
+                }
+                entities.append(entity)
+        else:
+            entities = self.aggregate_words(pre_entities, aggregation_strategy)
+
+        if aggregation_strategy == AggregationStrategy.NONE:
+            return entities
+        return self.group_entities(entities)
+
+    def aggregate_word(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> dict:
+        word = self.tokenizer.convert_tokens_to_string([entity["word"] for entity in entities])
+        if aggregation_strategy == AggregationStrategy.FIRST:
+            scores = entities[0]["scores"]
+            idx = scores.argmax()
+            score = scores[idx]
+            entity = self.model.config.id2label[idx]
+        elif aggregation_strategy == AggregationStrategy.MAX:
+            max_entity = max(entities, key=lambda entity: entity["scores"].max())
+            scores = max_entity["scores"]
+            idx = scores.argmax()
+            score = scores[idx]
+            entity = self.model.config.id2label[idx]
+        elif aggregation_strategy == AggregationStrategy.AVERAGE:
+            scores = np.stack([entity["scores"] for entity in entities])
+            average_scores = np.nanmean(scores, axis=0)
+            entity_idx = average_scores.argmax()
+            entity = self.model.config.id2label[entity_idx]
+            score = average_scores[entity_idx]
+        else:
+            raise ValueError("Invalid aggregation_strategy")
+        new_entity = {
+            "entity": entity,
+            "score": score,
+            "word": word,
+            "start": entities[0]["start"],
+            "end": entities[-1]["end"],
+        }
+        return new_entity
+
+    def aggregate_words(self, entities: List[dict], aggregation_strategy: AggregationStrategy) -> List[dict]:
+        """
+        Override tokens from a given word that disagree to force agreement on word boundaries.
+
+        Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with first strategy as microsoft|
+        company| B-ENT I-ENT
+        """
+        assert aggregation_strategy not in {
+            AggregationStrategy.NONE,
+            AggregationStrategy.SIMPLE,
+        }, "NONE and SIMPLE strategies are invalid"
+
+        word_entities = []
+        word_group = None
+        for entity in entities:
+            if word_group is None:
+                word_group = [entity]
+            elif entity["is_subword"]:
+                word_group.append(entity)
+            else:
+                word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
+                word_group = [entity]
+        # Last item
+        word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
+        return word_entities
+
     def group_sub_entities(self, entities: List[dict]) -> dict:
         """
@@ -249,6 +390,19 @@ class TokenClassificationPipeline(Pipeline):
         }
         return entity_group

+    def get_tag(self, entity_name: str) -> Tuple[str, str]:
+        if entity_name.startswith("B-"):
+            bi = "B"
+            tag = entity_name[2:]
+        elif entity_name.startswith("I-"):
+            bi = "I"
+            tag = entity_name[2:]
+        else:
+            # It's not in B-, I- format
+            bi = "B"
+            tag = entity_name
+        return bi, tag
+
     def group_entities(self, entities: List[dict]) -> List[dict]:
         """
         Find and group together the adjacent tokens with the same entity predicted.
@@ -260,45 +414,29 @@ class TokenClassificationPipeline(Pipeline):
         entity_groups = []
         entity_group_disagg = []

-        if entities:
-            last_idx = entities[-1]["index"]
-
         for entity in entities:
-            is_last_idx = entity["index"] == last_idx
-            is_subword = self.ignore_subwords and entity["is_subword"]
             if not entity_group_disagg:
-                entity_group_disagg += [entity]
-                if is_last_idx:
-                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                entity_group_disagg.append(entity)
                 continue

-            # If the current entity is similar and adjacent to the previous entity, append it to the disaggregated entity group
-            # The split is meant to account for the "B" and "I" suffixes
+            # If the current entity is similar and adjacent to the previous entity,
+            # append it to the disaggregated entity group
+            # The split is meant to account for the "B" and "I" prefixes
             # Shouldn't merge if both entities are B-type
-            if (
-                (
-                    entity["entity"].split("-")[-1] == entity_group_disagg[-1]["entity"].split("-")[-1]
-                    and entity["entity"].split("-")[0] != "B"
-                )
-                and entity["index"] == entity_group_disagg[-1]["index"] + 1
-            ) or is_subword:
+            bi, tag = self.get_tag(entity["entity"])
+            last_bi, last_tag = self.get_tag(entity_group_disagg[-1]["entity"])
+
+            if tag == last_tag and bi != "B":
                 # Modify subword type to be previous_type
-                if is_subword:
-                    entity["entity"] = entity_group_disagg[-1]["entity"].split("-")[-1]
-                    entity["score"] = np.nan  # set ignored scores to nan and use np.nanmean
-
-                entity_group_disagg += [entity]
-                # Group the entities at the last entity
-                if is_last_idx:
-                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
-            # If the current entity is different from the previous entity, aggregate the disaggregated entity group
             else:
-                entity_groups += [self.group_sub_entities(entity_group_disagg)]
+                entity_group_disagg.append(entity)
+            else:
+                # If the current entity is different from the previous entity
+                # aggregate the disaggregated entity group
+                entity_groups.append(self.group_sub_entities(entity_group_disagg))
                 entity_group_disagg = [entity]
-                # If it's the last entity, add it to the entity groups
-                if is_last_idx:
-                    entity_groups += [self.group_sub_entities(entity_group_disagg)]
+        if entity_group_disagg:
+            # it's the last entity, add it to the entity groups
+            entity_groups.append(self.group_sub_entities(entity_group_disagg))

         return entity_groups
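
As a worked illustration of the rewritten group_entities/get_tag pair (tags invented, other keys omitted): the sequence B-PER, I-PER, B-PER now yields two groups, because the third tag matches the previous tag ("PER") but opens with "B":

    # [{"entity": "B-PER"}, {"entity": "I-PER"}, {"entity": "B-PER"}]
    #   -> group 1: B-PER, I-PER   (same tag, bi != "B", so they merge)
    #   -> group 2: B-PER          (bi == "B" always starts a new group)
    # A bare "PER" parses as ("B", "PER"), so schemes without B-/I- prefixes
    # never merge adjacent tokens here.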
@@ -1207,19 +1207,25 @@ def nested_simplify(obj, decimals=3):
     Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
     within tests.
     """
+    import numpy as np
+
     from transformers.tokenization_utils import BatchEncoding

     if isinstance(obj, list):
         return [nested_simplify(item, decimals) for item in obj]
+    elif isinstance(obj, np.ndarray):
+        return nested_simplify(obj.tolist())
     elif isinstance(obj, (dict, BatchEncoding)):
         return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
-    elif isinstance(obj, (str, int)):
+    elif isinstance(obj, (str, int, np.int64)):
         return obj
     elif is_torch_available() and isinstance(obj, torch.Tensor):
-        return nested_simplify(obj.tolist())
+        return nested_simplify(obj.tolist(), decimals)
     elif is_tf_available() and tf.is_tensor(obj):
         return nested_simplify(obj.numpy().tolist())
     elif isinstance(obj, float):
         return round(obj, decimals)
+    elif isinstance(obj, np.float32):
+        return nested_simplify(obj.item(), decimals)
     else:
         raise Exception(f"Not supported: {type(obj)}")
@@ -14,15 +14,14 @@
 import unittest

-from transformers import AutoTokenizer, is_torch_available, pipeline
-from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
-from transformers.testing_utils import require_tf, require_torch, slow
+import numpy as np
+
+from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
+from transformers.pipelines import AggregationStrategy, Pipeline, TokenClassificationArgumentHandler
+from transformers.testing_utils import nested_simplify, require_tf, require_torch, slow
 from .test_pipelines_common import CustomInputPipelineCommonMixin

-if is_torch_available():
-    import numpy as np
-
 VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
@@ -35,242 +34,333 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
     large_models = []  # Models tested with the @slow decorator

     def _test_pipeline(self, nlp: Pipeline):
-        output_keys = {"entity", "word", "score", "start", "end"}
-        if nlp.grouped_entities:
+        output_keys = {"entity", "word", "score", "start", "end", "index"}
+        if nlp.aggregation_strategy != AggregationStrategy.NONE:
             output_keys = {"entity_group", "word", "score", "start", "end"}

-        ungrouped_ner_inputs = [
-            [
-                {"entity": "B-PER", "index": 1, "score": 0.9994944930076599, "is_subword": False, "word": "Cons", "start": 0, "end": 4},
-                {"entity": "B-PER", "index": 2, "score": 0.8025449514389038, "is_subword": True, "word": "##uelo", "start": 4, "end": 8},
-                {"entity": "I-PER", "index": 3, "score": 0.9993102550506592, "is_subword": False, "word": "Ara", "start": 9, "end": 11},
-                {"entity": "I-PER", "index": 4, "score": 0.9993743896484375, "is_subword": True, "word": "##új", "start": 11, "end": 13},
-                {"entity": "I-PER", "index": 5, "score": 0.9992871880531311, "is_subword": True, "word": "##o", "start": 13, "end": 14},
-                {"entity": "I-PER", "index": 6, "score": 0.9993029236793518, "is_subword": False, "word": "No", "start": 15, "end": 17},
-                {"entity": "I-PER", "index": 7, "score": 0.9981776475906372, "is_subword": True, "word": "##guera", "start": 17, "end": 22},
-                {"entity": "B-PER", "index": 15, "score": 0.9998136162757874, "is_subword": False, "word": "Andrés", "start": 23, "end": 28},
-                {"entity": "I-PER", "index": 16, "score": 0.999740719795227, "is_subword": False, "word": "Pas", "start": 29, "end": 32},
-                {"entity": "I-PER", "index": 17, "score": 0.9997414350509644, "is_subword": True, "word": "##tran", "start": 32, "end": 36},
-                {"entity": "I-PER", "index": 18, "score": 0.9996136426925659, "is_subword": True, "word": "##a", "start": 36, "end": 37},
-                {"entity": "B-ORG", "index": 28, "score": 0.9989739060401917, "is_subword": False, "word": "Far", "start": 39, "end": 42},
-                {"entity": "I-ORG", "index": 29, "score": 0.7188422083854675, "is_subword": True, "word": "##c", "start": 42, "end": 43},
-            ],
-            [
-                {"entity": "I-PER", "index": 1, "score": 0.9968166351318359, "is_subword": False, "word": "En", "start": 0, "end": 2},
-                {"entity": "I-PER", "index": 2, "score": 0.9957635998725891, "is_subword": True, "word": "##zo", "start": 2, "end": 4},
-                {"entity": "I-ORG", "index": 7, "score": 0.9986497163772583, "is_subword": False, "word": "UN", "start": 11, "end": 13},
-            ],
-        ]
-
-        expected_grouped_ner_results = [
-            [
-                {"entity_group": "PER", "score": 0.999369223912557, "word": "Consuelo Araújo Noguera", "start": 0, "end": 22},
-                {"entity_group": "PER", "score": 0.9997771680355072, "word": "Andrés Pastrana", "start": 23, "end": 37},
-                {"entity_group": "ORG", "score": 0.9989739060401917, "word": "Farc", "start": 39, "end": 43},
-            ],
-            [
-                {"entity_group": "PER", "score": 0.9968166351318359, "word": "Enzo", "start": 0, "end": 4},
-                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
-            ],
-        ]
-
-        expected_grouped_ner_results_w_subword = [
-            [
-                {"entity_group": "PER", "score": 0.9994944930076599, "word": "Cons", "start": 0, "end": 4},
-                {"entity_group": "PER", "score": 0.9663328925768534, "word": "##uelo Araújo Noguera", "start": 4, "end": 22},
-                {"entity_group": "PER", "score": 0.9997273534536362, "word": "Andrés Pastrana", "start": 23, "end": 37},
-                {"entity_group": "ORG", "score": 0.8589080572128296, "word": "Farc", "start": 39, "end": 43},
-            ],
-            [
-                {"entity_group": "PER", "score": 0.9962901175022125, "word": "Enzo", "start": 0, "end": 4},
-                {"entity_group": "ORG", "score": 0.9986497163772583, "word": "UN", "start": 11, "end": 13},
-            ],
-        ]
-
         self.assertIsNotNone(nlp)

         mono_result = nlp(VALID_INPUTS[0])
         self.assertIsInstance(mono_result, list)
         self.assertIsInstance(mono_result[0], (dict, list))

         if isinstance(mono_result[0], list):
             mono_result = mono_result[0]

         for key in output_keys:
             self.assertIn(key, mono_result[0])

         multi_result = [nlp(input) for input in VALID_INPUTS]
         self.assertIsInstance(multi_result, list)
         self.assertIsInstance(multi_result[0], (dict, list))

         if isinstance(multi_result[0], list):
             multi_result = multi_result[0]

         for result in multi_result:
             for key in output_keys:
                 self.assertIn(key, result)

-        if nlp.grouped_entities:
-            if nlp.ignore_subwords:
-                for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results):
-                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
-            else:
-                for ungrouped_input, grouped_result in zip(ungrouped_ner_inputs, expected_grouped_ner_results_w_subword):
-                    self.assertEqual(nlp.group_entities(ungrouped_input), grouped_result)
+    @require_torch
+    @slow
+    def test_spanish_bert(self):
+        # https://github.com/huggingface/transformers/pull/4987
+        NER_MODEL = "mrm8488/bert-spanish-cased-finetuned-ner"
+        model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
+        sentence = """Consuelo Araújo Noguera, ministra de cultura del presidente Andrés Pastrana (1998.2002) fue asesinada por las Farc luego de haber permanecido secuestrada por algunos meses."""
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer)
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output[:3]),
+            [
+                {"entity": "B-PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4, "index": 1},
+                {"entity": "B-PER", "score": 0.803, "word": "##uelo", "start": 4, "end": 8, "index": 2},
+                {"entity": "I-PER", "score": 0.999, "word": "Ara", "start": 9, "end": 12, "index": 3},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output[:3]),
+            [
+                {"entity_group": "PER", "score": 0.999, "word": "Cons", "start": 0, "end": 4},
+                {"entity_group": "PER", "score": 0.966, "word": "##uelo Araújo Noguera", "start": 4, "end": 23},
+                {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output[:3]),
+            [
+                {"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
+                {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
+                {"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output[:3]),
+            [
+                {"entity_group": "PER", "score": 0.999, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
+                {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
+                {"entity_group": "ORG", "score": 0.999, "word": "Farc", "start": 110, "end": 114},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output[:3]),
+            [
+                {"entity_group": "PER", "score": 0.966, "word": "Consuelo Araújo Noguera", "start": 0, "end": 23},
+                {"entity_group": "PER", "score": 1.0, "word": "Andrés Pastrana", "start": 60, "end": 75},
+                {"entity_group": "ORG", "score": 0.542, "word": "Farc", "start": 110, "end": 114},
+            ],
+        )
+
+    @require_torch
+    @slow
+    def test_dbmdz_english(self):
+        # Other sentence
+        NER_MODEL = "dbmdz/bert-large-cased-finetuned-conll03-english"
+        model = AutoModelForTokenClassification.from_pretrained(NER_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_fast=True)
+        sentence = """Enzo works at the the UN"""
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer)
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity": "I-PER", "score": 0.997, "word": "En", "start": 0, "end": 2, "index": 1},
+                {"entity": "I-PER", "score": 0.996, "word": "##zo", "start": 2, "end": 4, "index": 2},
+                {"entity": "I-ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24, "index": 7},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output[:3]),
+            [
+                {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="max")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output[:3]),
+            [
+                {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+            ],
+        )
+
+        token_classifier = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="average")
+        output = token_classifier(sentence)
+        self.assertEqual(
+            nested_simplify(output),
+            [
+                {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 22, "end": 24},
+            ],
+        )
+
+    @require_torch
+    def test_aggregation_strategy(self):
+        model_name = self.small_models[0]
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
+        # Just to understand scores indexes in this test
+        self.assertEqual(
+            token_classifier.model.config.id2label,
+            {0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
+        )
+        example = [
+            {
+                # fmt: off
+                "scores": np.array([0, 0, 0, 0, 0.9968166351318359, 0, 0, 0]),
+                # fmt: on
+                "index": 1,
+                "is_subword": False,
+                "word": "En",
+                "start": 0,
+                "end": 2,
+            },
+            {
+                # fmt: off
+                "scores": np.array([0, 0, 0, 0, 0.9957635998725891, 0, 0, 0]),
+                # fmt: on
+                "index": 2,
+                "is_subword": True,
+                "word": "##zo",
+                "start": 2,
+                "end": 4,
+            },
+            {
+                # fmt: off
+                "scores": np.array([0, 0, 0, 0, 0, 0.9986497163772583, 0, 0, ]),
+                # fmt: on
+                "index": 7,
+                "word": "UN",
+                "is_subword": False,
+                "start": 11,
+                "end": 13,
+            },
+        ]
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.NONE)),
+            [
+                {"end": 2, "entity": "I-PER", "score": 0.997, "start": 0, "word": "En", "index": 1},
+                {"end": 4, "entity": "I-PER", "score": 0.996, "start": 2, "word": "##zo", "index": 2},
+                {"end": 13, "entity": "B-ORG", "score": 0.999, "start": 11, "word": "UN", "index": 7},
+            ],
+        )
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.SIMPLE)),
+            [
+                {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
+            ],
+        )
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.FIRST)),
+            [
+                {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
+            ],
+        )
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.MAX)),
+            [
+                {"entity_group": "PER", "score": 0.997, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
+            ],
+        )
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)),
+            [
+                {"entity_group": "PER", "score": 0.996, "word": "Enzo", "start": 0, "end": 4},
+                {"entity_group": "ORG", "score": 0.999, "word": "UN", "start": 11, "end": 13},
+            ],
+        )
+
+    @require_torch
+    def test_aggregation_strategy_example2(self):
+        model_name = self.small_models[0]
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        token_classifier = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
+        # Just to understand scores indexes in this test
+        self.assertEqual(
+            token_classifier.model.config.id2label,
+            {0: "O", 1: "B-MISC", 2: "I-MISC", 3: "B-PER", 4: "I-PER", 5: "B-ORG", 6: "I-ORG", 7: "B-LOC", 8: "I-LOC"},
+        )
+        example = [
+            {
+                # Necessary for AVERAGE
+                "scores": np.array([0, 0.55, 0, 0.45, 0, 0, 0, 0, 0, 0]),
+                "is_subword": False,
+                "index": 1,
+                "word": "Ra",
+                "start": 0,
+                "end": 2,
+            },
+            {
+                "scores": np.array([0, 0, 0, 0.2, 0, 0, 0, 0.8, 0, 0]),
+                "is_subword": True,
+                "word": "##ma",
+                "start": 2,
+                "end": 4,
+                "index": 2,
+            },
+            {
+                # 4th score will have the higher average
+                # 4th score is B-PER for this model
+                # It does not correspond to any of the subtokens.
+                "scores": np.array([0, 0, 0, 0.4, 0, 0, 0.6, 0, 0, 0]),
+                "is_subword": True,
+                "word": "##zotti",
+                "start": 11,
+                "end": 13,
+                "index": 3,
+            },
+        ]
+        self.assertEqual(
+            token_classifier.aggregate(example, AggregationStrategy.NONE),
+            [
+                {"end": 2, "entity": "B-MISC", "score": 0.55, "start": 0, "word": "Ra", "index": 1},
+                {"end": 4, "entity": "B-LOC", "score": 0.8, "start": 2, "word": "##ma", "index": 2},
+                {"end": 13, "entity": "I-ORG", "score": 0.6, "start": 11, "word": "##zotti", "index": 3},
+            ],
+        )
+        self.assertEqual(
+            token_classifier.aggregate(example, AggregationStrategy.FIRST),
+            [{"entity_group": "MISC", "score": 0.55, "word": "Ramazotti", "start": 0, "end": 13}],
+        )
+        self.assertEqual(
+            token_classifier.aggregate(example, AggregationStrategy.MAX),
+            [{"entity_group": "LOC", "score": 0.8, "word": "Ramazotti", "start": 0, "end": 13}],
+        )
+        self.assertEqual(
+            nested_simplify(token_classifier.aggregate(example, AggregationStrategy.AVERAGE)),
+            [{"entity_group": "PER", "score": 0.35, "word": "Ramazotti", "start": 0, "end": 13}],
+        )
+
+    @require_torch
+    def test_gather_pre_entities(self):
+        model_name = self.small_models[0]
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+        nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer, framework="pt")
+
+        sentence = "Hello there"
+
+        tokens = tokenizer(
+            sentence,
+            return_attention_mask=False,
+            return_tensors="pt",
+            truncation=True,
+            return_special_tokens_mask=True,
+            return_offsets_mapping=True,
+        )
+        offset_mapping = tokens.pop("offset_mapping").cpu().numpy()[0]
+        special_tokens_mask = tokens.pop("special_tokens_mask").cpu().numpy()[0]
+        input_ids = tokens["input_ids"].numpy()[0]
+        # First element is [CLS]
+        scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]])
+
+        pre_entities = nlp.gather_pre_entities(sentence, input_ids, scores, offset_mapping, special_tokens_mask)
+        self.assertEqual(
+            nested_simplify(pre_entities),
+            [
+                {"word": "Hello", "scores": [0.1, 0.3, 0.6], "start": 0, "end": 5, "is_subword": False, "index": 1},
+                {"word": "there", "scores": [0.8, 0.1, 0.1], "index": 2, "start": 6, "end": 11, "is_subword": False},
+            ],
+        )

     @require_tf
     def test_tf_only(self):
@@ -295,8 +385,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
             model=model_name,
             tokenizer=tokenizer,
             framework="tf",
-            grouped_entities=True,
-            ignore_subwords=True,
+            aggregation_strategy=AggregationStrategy.FIRST,
         )
         self._test_pipeline(nlp)
@@ -307,18 +396,23 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
             model=model_name,
             tokenizer=tokenizer,
             framework="tf",
-            grouped_entities=True,
-            ignore_subwords=False,
+            aggregation_strategy=AggregationStrategy.SIMPLE,
         )
         self._test_pipeline(nlp)

     @require_torch
     def test_pt_ignore_subwords_slow_tokenizer_raises(self):
-        for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
-            with self.assertRaises(ValueError):
-                pipeline(task="ner", model=model_name, tokenizer=tokenizer, ignore_subwords=True, use_fast=False)
+        model_name = self.small_models[0]
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+
+        with self.assertRaises(ValueError):
+            pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST)
+        with self.assertRaises(ValueError):
+            pipeline(
+                task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.AVERAGE
+            )
+        with self.assertRaises(ValueError):
+            pipeline(task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.MAX)

     @require_torch
     def test_pt_defaults_slow_tokenizer(self):
@@ -333,27 +427,27 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
         nlp = pipeline(task="ner", model=model_name)
         self._test_pipeline(nlp)

+    @slow
+    @require_torch
+    def test_warnings(self):
+        with self.assertWarns(UserWarning):
+            token_classifier = pipeline(task="ner", model=self.small_models[0], grouped_entities=True)
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE)
+        with self.assertWarns(UserWarning):
+            token_classifier = pipeline(
+                task="ner", model=self.small_models[0], grouped_entities=True, ignore_subwords=True
+            )
+        self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST)
+
     @slow
     @require_torch
     def test_simple(self):
-        nlp = pipeline(task="ner", model="dslim/bert-base-NER", grouped_entities=True)
+        nlp = pipeline(task="ner", model="dslim/bert-base-NER", aggregation_strategy=AggregationStrategy.SIMPLE)
         sentence = "Hello Sarah Jessica Parker who Jessica lives in New York"
         sentence2 = "This is a simple test"
         output = nlp(sentence)

-        def simplify(output):
-            if isinstance(output, (list, tuple)):
-                return [simplify(item) for item in output]
-            elif isinstance(output, dict):
-                return {simplify(k): simplify(v) for k, v in output.items()}
-            elif isinstance(output, (str, int, np.int64)):
-                return output
-            elif isinstance(output, float):
-                return round(output, 3)
-            else:
-                raise Exception(f"Cannot handle {type(output)}")
-
-        output_ = simplify(output)
+        output_ = nested_simplify(output)

         self.assertEqual(
             output_,
@@ -371,7 +465,7 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
         )

         output = nlp([sentence, sentence2])
-        output_ = simplify(output)
+        output_ = nested_simplify(output)

         self.assertEqual(
             output_,
@@ -390,14 +484,14 @@ class TokenClassificationPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
         for model_name in self.small_models:
             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
             nlp = pipeline(
-                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
+                task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.FIRST
             )
             self._test_pipeline(nlp)

         for model_name in self.small_models:
             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
             nlp = pipeline(
-                task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
+                task="ner", model=model_name, tokenizer=tokenizer, aggregation_strategy=AggregationStrategy.SIMPLE
             )
             self._test_pipeline(nlp)