Unverified Commit c63fcabf authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

[Large PR] Entire rework of pipelines. (#13308)



* Enabling dataset iteration on pipelines.

Enabling dataset iteration on pipelines.

Unifying parameters under `set_parameters` function.

Small fix.

Last fixes after rebase

Remove print.

Fixing text2text `generate_kwargs`

No more `self.max_length`.

Fixing tf only conversational.

Consistency in start/stop index over TF/PT.

Speeding up drastically on TF (nasty bug where max_length would increase
a ton.)

Adding test for support for non fast tokenizers.

Fixing GPU usage on zero-shot.

Fix working on Tf.

Update src/transformers/pipelines/base.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

Update src/transformers/pipelines/base.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

Small cleanup.

Remove all asserts + simple format.

* Fixing audio-classification for large PR.

* Overly explicit null checking.

* Encapsulating GPU/CPU pytorch manipulation directly within `base.py`.

* Removed internal state for parameters of the pipeline.

Instead of overriding implicitly internal state, we moved
to real named arguments on every `preprocess`, `_forward`,
`postprocess` function.

Instead `_sanitize_parameters` will be used to split all kwargs
of both __init__ and __call__ into the 3 kinds of named parameters.

* Move import warnings.

* Small fixes.

* Quality.

* Another small fix, using the CI to debug faster.

* Last fixes.

* Last fix.

* Small cleanup of tensor moving.

* is not None.

* Adding a bunch of docs + a iteration test.

* Fixing doc style.

* KeyDataset = None guard.

* Removing the CUDA test for pipelines (was testing).

* Even more simple iteration test.

* Correct import.

* Long day.

* Fixes in docs.

* [WIP] migrating object detection.

* Fixed the target_size bug.

* Fixup.

* Bad variable name.

* Fixing `ensure_on_device` respects original ModelOutput.
parent 09549aa1
...@@ -1343,6 +1343,8 @@ def nested_simplify(obj, decimals=3): ...@@ -1343,6 +1343,8 @@ def nested_simplify(obj, decimals=3):
return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
elif isinstance(obj, (str, int, np.int64)): elif isinstance(obj, (str, int, np.int64)):
return obj return obj
elif obj is None:
return obj
elif is_torch_available() and isinstance(obj, torch.Tensor): elif is_torch_available() and isinstance(obj, torch.Tensor):
return nested_simplify(obj.tolist(), decimals) return nested_simplify(obj.tolist(), decimals)
elif is_tf_available() and tf.is_tensor(obj): elif is_tf_available() and tf.is_tensor(obj):
......
...@@ -15,11 +15,13 @@ ...@@ -15,11 +15,13 @@
import importlib import importlib
import logging import logging
import string import string
import unittest
from abc import abstractmethod from abc import abstractmethod
from functools import lru_cache from functools import lru_cache
from unittest import skipIf from unittest import skipIf
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer, pipeline
from transformers.testing_utils import is_pipeline_test, require_torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -177,3 +179,30 @@ class PipelineTestCaseMeta(type): ...@@ -177,3 +179,30 @@ class PipelineTestCaseMeta(type):
dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner) dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
return type.__new__(mcs, name, bases, dct) return type.__new__(mcs, name, bases, dct)
@is_pipeline_test
class CommonPipelineTest(unittest.TestCase):
@require_torch
def test_pipeline_iteration(self):
from torch.utils.data import Dataset
class MyDataset(Dataset):
data = [
"This is a test",
"This restaurant is great",
"This restaurant is awful",
]
def __len__(self):
return 3
def __getitem__(self, i):
return self.data[i]
text_classifier = pipeline(
task="text-classification", model="Narsil/tiny-distilbert-sequence-classification", framework="pt"
)
dataset = MyDataset()
for output in text_classifier(dataset):
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
...@@ -187,24 +187,15 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM ...@@ -187,24 +187,15 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer) conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
conversation_1 = Conversation("hello") conversation_1 = Conversation("hello")
inputs = conversation_agent._parse_and_tokenize([conversation_1]) inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]]) self.assertEqual(inputs["input_ids"].tolist(), [[31373, 50256]])
conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"]) conversation_2 = Conversation("how are you ?", past_user_inputs=["hello"], generated_responses=["Hi there!"])
inputs = conversation_agent._parse_and_tokenize([conversation_2]) inputs = conversation_agent.preprocess(conversation_2)
self.assertEqual( self.assertEqual(
inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]] inputs["input_ids"].tolist(), [[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256]]
) )
inputs = conversation_agent._parse_and_tokenize([conversation_1, conversation_2])
self.assertEqual(
inputs["input_ids"].tolist(),
[
[31373, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
[31373, 50256, 17250, 612, 0, 50256, 4919, 389, 345, 5633, 50256],
],
)
@require_torch @require_torch
@slow @slow
def test_integration_torch_conversation_blenderbot_400M_input_ids(self): def test_integration_torch_conversation_blenderbot_400M_input_ids(self):
...@@ -214,7 +205,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM ...@@ -214,7 +205,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
# test1 # test1
conversation_1 = Conversation("hello") conversation_1 = Conversation("hello")
inputs = conversation_agent._parse_and_tokenize([conversation_1]) inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]]) self.assertEqual(inputs["input_ids"].tolist(), [[1710, 86, 2]])
# test2 # test2
...@@ -225,7 +216,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM ...@@ -225,7 +216,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
" Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie." " Do you like lasagne? It is a traditional Italian dish consisting of a shepherd's pie."
], ],
) )
inputs = conversation_agent._parse_and_tokenize([conversation_1]) inputs = conversation_agent.preprocess(conversation_1)
self.assertEqual( self.assertEqual(
inputs["input_ids"].tolist(), inputs["input_ids"].tolist(),
[ [
...@@ -271,7 +262,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM ...@@ -271,7 +262,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
964, 964,
21, 21,
2, # EOS 2, # EOS
] ],
], ],
) )
......
...@@ -91,6 +91,8 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -91,6 +91,8 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
shape = self.get_shape(outputs) shape = self.get_shape(outputs)
self.assertEqual(shape[0], 1) self.assertEqual(shape[0], 1)
outputs = feature_extractor(["This is a test", "Another test"]) # If we send too small input
# there's a bug within FunnelModel (output with shape [1, 4, 2, 1] doesn't match the broadcast shape [1, 4, 2, 2])
outputs = feature_extractor(["This is a test", "Another longer test"])
shape = self.get_shape(outputs) shape = self.get_shape(outputs)
self.assertEqual(shape[0], 2) self.assertEqual(shape[0], 2)
...@@ -186,7 +186,7 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): ...@@ -186,7 +186,7 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
], ],
) )
outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token}"]) outputs = fill_masker([f"This is a {tokenizer.mask_token}", f"Another {tokenizer.mask_token} great test."])
self.assertEqual( self.assertEqual(
outputs, outputs,
[ [
......
...@@ -116,8 +116,8 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -116,8 +116,8 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}}, {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}}, {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
], ],
) )
...@@ -133,12 +133,12 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -133,12 +133,12 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}}, {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}}, {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
], ],
[ [
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}}, {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
{"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 266, "ymin": 200, "xmax": 799, "ymax": 599}}, {"score": 0.3432, "label": "LABEL_0", "box": {"xmin": 160, "ymin": 120, "xmax": 480, "ymax": 359}},
], ],
], ],
) )
...@@ -156,11 +156,11 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -156,11 +156,11 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, {"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, {"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
) )
...@@ -174,18 +174,18 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -174,18 +174,18 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, {"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, {"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
[ [
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, {"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, {"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
], ],
) )
...@@ -201,11 +201,11 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -201,11 +201,11 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, {"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, {"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
) )
...@@ -219,18 +219,18 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -219,18 +219,18 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
[ [
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, {"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, {"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
[ [
{"score": 0.9982, "label": "remote", "box": {"xmin": 66, "ymin": 118, "xmax": 292, "ymax": 196}}, {"score": 0.9982, "label": "remote", "box": {"xmin": 40, "ymin": 70, "xmax": 175, "ymax": 117}},
{"score": 0.9960, "label": "remote", "box": {"xmin": 555, "ymin": 120, "xmax": 613, "ymax": 312}}, {"score": 0.9960, "label": "remote", "box": {"xmin": 333, "ymin": 72, "xmax": 368, "ymax": 187}},
{"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 1065, "ymax": 789}}, {"score": 0.9955, "label": "couch", "box": {"xmin": 0, "ymin": 1, "xmax": 639, "ymax": 473}},
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
], ],
) )
...@@ -247,7 +247,7 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -247,7 +247,7 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{"score": 0.9988, "label": "cat", "box": {"xmin": 22, "ymin": 86, "xmax": 523, "ymax": 784}}, {"score": 0.9988, "label": "cat", "box": {"xmin": 13, "ymin": 52, "xmax": 314, "ymax": 470}},
{"score": 0.9987, "label": "cat", "box": {"xmin": 575, "ymin": 39, "xmax": 1066, "ymax": 614}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
) )
...@@ -96,7 +96,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -96,7 +96,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
def run_aggregation_strategy(self, model, tokenizer): def run_aggregation_strategy(self, model, tokenizer):
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple") token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="simple")
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE) self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
outputs = token_classifier("A simple string") outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list) self.assertIsInstance(outputs, list)
n = len(outputs) n = len(outputs)
...@@ -115,7 +115,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -115,7 +115,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
) )
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="first") token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="first")
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST) self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
outputs = token_classifier("A simple string") outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list) self.assertIsInstance(outputs, list)
n = len(outputs) n = len(outputs)
...@@ -134,7 +134,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -134,7 +134,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
) )
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="max") token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer, aggregation_strategy="max")
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.MAX) self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.MAX)
outputs = token_classifier("A simple string") outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list) self.assertIsInstance(outputs, list)
n = len(outputs) n = len(outputs)
...@@ -155,7 +155,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -155,7 +155,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
token_classifier = TokenClassificationPipeline( token_classifier = TokenClassificationPipeline(
model=model, tokenizer=tokenizer, aggregation_strategy="average" model=model, tokenizer=tokenizer, aggregation_strategy="average"
) )
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.AVERAGE) self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.AVERAGE)
outputs = token_classifier("A simple string") outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list) self.assertIsInstance(outputs, list)
n = len(outputs) n = len(outputs)
...@@ -175,12 +175,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -175,12 +175,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
with self.assertWarns(UserWarning): with self.assertWarns(UserWarning):
token_classifier = pipeline(task="ner", model=model, tokenizer=tokenizer, grouped_entities=True) token_classifier = pipeline(task="ner", model=model, tokenizer=tokenizer, grouped_entities=True)
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.SIMPLE) self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.SIMPLE)
with self.assertWarns(UserWarning): with self.assertWarns(UserWarning):
token_classifier = pipeline( token_classifier = pipeline(
task="ner", model=model, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True task="ner", model=model, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=True
) )
self.assertEqual(token_classifier.aggregation_strategy, AggregationStrategy.FIRST) self.assertEqual(token_classifier._postprocess_params["aggregation_strategy"], AggregationStrategy.FIRST)
@require_torch @require_torch
@slow @slow
...@@ -533,7 +533,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -533,7 +533,12 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]]) scores = np.array([[1, 0, 0], [0.1, 0.3, 0.6], [0.8, 0.1, 0.1]])
pre_entities = token_classifier.gather_pre_entities( pre_entities = token_classifier.gather_pre_entities(
sentence, input_ids, scores, offset_mapping, special_tokens_mask sentence,
input_ids,
scores,
offset_mapping,
special_tokens_mask,
aggregation_strategy=AggregationStrategy.NONE,
) )
self.assertEqual( self.assertEqual(
nested_simplify(pre_entities), nested_simplify(pre_entities),
...@@ -570,6 +575,20 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest ...@@ -570,6 +575,20 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
], ],
) )
@require_torch
def test_no_offset_tokenizer(self):
model_name = "Narsil/small2"
tokenizer = AutoTokenizer.from_pretrained("Narsil/small2", use_fast=False)
token_classifier = pipeline(task="token-classification", model=model_name, tokenizer=tokenizer, framework="pt")
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": None, "end": None},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": None, "end": None},
],
)
@require_torch @require_torch
def test_small_model_pt(self): def test_small_model_pt(self):
model_name = "Narsil/small2" model_name = "Narsil/small2"
......
...@@ -108,8 +108,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase): ...@@ -108,8 +108,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
# but we do for this one # but we do for this one
translator = pipeline(task="translation_en_to_de") translator = pipeline(task="translation_en_to_de")
self.assertEquals(translator.src_lang, "en") self.assertEqual(translator._preprocess_params["src_lang"], "en")
self.assertEquals(translator.tgt_lang, "de") self.assertEqual(translator._preprocess_params["tgt_lang"], "de")
@require_torch @require_torch
@slow @slow
...@@ -137,8 +137,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase): ...@@ -137,8 +137,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
def test_translation_on_odd_language(self): def test_translation_on_odd_language(self):
model = "patrickvonplaten/t5-tiny-random" model = "patrickvonplaten/t5-tiny-random"
translator = pipeline(task="translation_cn_to_ar", model=model) translator = pipeline(task="translation_cn_to_ar", model=model)
self.assertEquals(translator.src_lang, "cn") self.assertEqual(translator._preprocess_params["src_lang"], "cn")
self.assertEquals(translator.tgt_lang, "ar") self.assertEqual(translator._preprocess_params["tgt_lang"], "ar")
@require_torch @require_torch
def test_translation_default_language_selection(self): def test_translation_default_language_selection(self):
...@@ -146,8 +146,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase): ...@@ -146,8 +146,8 @@ class TranslationNewFormatPipelineTests(unittest.TestCase):
with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"): with pytest.warns(UserWarning, match=r".*translation_en_to_de.*"):
translator = pipeline(task="translation", model=model) translator = pipeline(task="translation", model=model)
self.assertEqual(translator.task, "translation_en_to_de") self.assertEqual(translator.task, "translation_en_to_de")
self.assertEqual(translator.src_lang, "en") self.assertEqual(translator._preprocess_params["src_lang"], "en")
self.assertEqual(translator.tgt_lang, "de") self.assertEqual(translator._preprocess_params["tgt_lang"], "de")
@require_torch @require_torch
def test_translation_with_no_language_no_model_fails(self): def test_translation_with_no_language_no_model_fails(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment