Unverified Commit 83bfdbdd authored by Nicolas Patry, committed by GitHub

Migrating conversational pipeline tests to new testing format (#13114)

* New test format for conversational.

* Putting back old mixin.

* Re-enabling auto tests with LazyLoading.

* Feature extraction tests.

* Remove feature-extraction.

* Feature extraction with feature_extractor (No pun intended).

* Update check_model_type for fill-mask.
parent 72eefb34
@@ -498,6 +498,15 @@ class _LazyAutoMapping(OrderedDict):
             if key in self._model_mapping.keys()
         ]
 
+    def get(self, key, default):
+        try:
+            return self.__getitem__(key)
+        except KeyError:
+            return default
+
+    def __bool__(self):
+        return bool(self.keys())
+
     def values(self):
         return [
             self._load_attr_from_module(key, name)
...
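The new get and __bool__ give _LazyAutoMapping just enough of a plain-dict surface for the test metaclass further down, which calls FEATURE_EXTRACTOR_MAPPING.get(configuration, None) and truth-tests mappings such as MODEL_FOR_CAUSAL_LM_MAPPING. A minimal, self-contained sketch of that dict-like behaviour (a toy class, not the real _LazyAutoMapping):

class _ToyLazyMapping:
    """Toy stand-in illustrating the two methods added to _LazyAutoMapping."""

    def __init__(self, data):
        self._data = data

    def __getitem__(self, key):
        return self._data[key]

    def keys(self):
        return self._data.keys()

    def get(self, key, default):
        # Fall back to `default` instead of raising KeyError.
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __bool__(self):
        # Lets callers write `if SOME_MAPPING:` on a lazy mapping.
        return bool(self.keys())


mapping = _ToyLazyMapping({"bert": "BertModel"})
print(mapping.get("gpt2", None))  # None, no KeyError
print(bool(mapping))              # True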
@@ -748,7 +748,7 @@ class Pipeline(_ScikitCompat):
         Parse arguments and tokenize
         """
         # Parse arguments
-        if self.tokenizer.pad_token is None:
+        if getattr(self.tokenizer, "pad_token", None) is None:
             padding = False
         inputs = self.tokenizer(
             inputs,
...
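The getattr guard matters now that the object sitting in the tokenizer slot may be a preprocessor with no pad_token attribute at all; the old attribute access would raise instead of falling back to unpadded inputs. A tiny illustration with a toy object (not a real tokenizer):

class _NoPadProcessor:
    """Toy object standing in for a preprocessor without a pad_token attribute."""


processor = _NoPadProcessor()
padding = True

# The old form `processor.pad_token is None` would raise AttributeError here.
if getattr(processor, "pad_token", None) is None:
    padding = False

print(padding)  # False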
@@ -159,6 +159,8 @@ class Conversation:
     r"""
         min_length_for_response (:obj:`int`, `optional`, defaults to 32):
             The minimum length (in number of tokens) for a response.
+        minimum_tokens (:obj:`int`, `optional`, defaults to 10):
+            The minimum length of tokens to leave for a response.
     """,
 )
 class ConversationalPipeline(Pipeline):
@@ -188,15 +190,16 @@ class ConversationalPipeline(Pipeline):
         conversational_pipeline([conversation_1, conversation_2])
     """
 
-    def __init__(self, min_length_for_response=32, *args, **kwargs):
+    def __init__(self, min_length_for_response=32, minimum_tokens=10, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         # We need at least an eos_token
-        assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
+        # assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
         if self.tokenizer.pad_token_id is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.min_length_for_response = min_length_for_response
+        self.minimum_tokens = minimum_tokens
 
     def __call__(
         self,
@@ -251,6 +254,16 @@ class ConversationalPipeline(Pipeline):
             elif self.framework == "tf":
                 input_length = tf.shape(inputs["input_ids"])[-1].numpy()
 
+            max_length = generate_kwargs.get("max_length", self.model.config.max_length)
+            n = inputs["input_ids"].shape[1]
+            if max_length - self.minimum_tokens < n:
+                logger.warning(
+                    f"Conversation input is too long ({n}), trimming it to ({max_length} - {self.minimum_tokens})"
+                )
+                trim = max_length - self.minimum_tokens
+                inputs["input_ids"] = inputs["input_ids"][:, -trim:]
+                inputs["attention_mask"] = inputs["attention_mask"][:, -trim:]
+
             generated_responses = self.model.generate(
                 inputs["input_ids"],
                 attention_mask=inputs["attention_mask"],
@@ -324,10 +337,13 @@ class ConversationalPipeline(Pipeline):
         eos_token_id = self.tokenizer.eos_token_id
         input_ids = []
         for is_user, text in conversation.iter_texts():
-            input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
+            if eos_token_id is not None:
+                input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
+            else:
+                input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False))
 
         if len(input_ids) > self.tokenizer.model_max_length:
-            input_ids = input_ids[-self.model_max_length :]
+            input_ids = input_ids[-self.tokenizer.model_max_length :]
         return input_ids
 
     def _parse_and_tokenize(self, conversations: List[Conversation]) -> Dict[str, Any]:
...
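The trimming block keeps only the most recent tokens of the conversation so that at least minimum_tokens positions of the generation budget remain free for the reply. The arithmetic, with made-up numbers and plain lists instead of tensors:

max_length = 20              # generation budget (generate kwarg or model.config.max_length)
minimum_tokens = 10          # positions reserved for the response
input_ids = list(range(25))  # pretend 25-token conversation history

n = len(input_ids)
if max_length - minimum_tokens < n:
    trim = max_length - minimum_tokens
    # Keep the last `trim` tokens, i.e. the most recent part of the history.
    input_ids = input_ids[-trim:]

print(len(input_ids))  # 10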
 from typing import TYPE_CHECKING, Optional, Union
 
+from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..modelcard import ModelCard
 from ..tokenization_utils import PreTrainedTokenizer
 from .base import ArgumentHandler, Pipeline
@@ -52,6 +53,7 @@ class FeatureExtractionPipeline(Pipeline):
         self,
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         tokenizer: PreTrainedTokenizer,
+        feature_extractor: Optional[PreTrainedFeatureExtractor] = None,
         modelcard: Optional[ModelCard] = None,
         framework: Optional[str] = None,
         args_parser: ArgumentHandler = None,
@@ -61,6 +63,7 @@ class FeatureExtractionPipeline(Pipeline):
         super().__init__(
             model=model,
             tokenizer=tokenizer,
+            feature_extractor=feature_extractor,
             modelcard=modelcard,
             framework=framework,
             args_parser=args_parser,
...
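Since the new argument defaults to None and is simply forwarded to the base Pipeline, existing callers are unaffected. A hedged usage sketch mirroring the test below (the tiny checkpoint is reused from the fill-mask tests in this commit and needs to be downloaded):

from transformers import AutoModel, AutoTokenizer, FeatureExtractionPipeline

checkpoint = "sshleifer/tiny-distilroberta-base"  # illustrative tiny checkpoint
model = AutoModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# feature_extractor is optional; passing None keeps the previous behaviour.
pipe = FeatureExtractionPipeline(model=model, tokenizer=tokenizer, feature_extractor=None)
features = pipe("Hello world")
print(type(features))  # nested list of floats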
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
 if is_tf_available():
     import tensorflow as tf
 
-    from ..models.auto.modeling_tf_auto import TF_MODEL_WITH_LM_HEAD_MAPPING
+    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_MASKED_LM_MAPPING
 
 if is_torch_available():
     import torch
@@ -81,7 +81,9 @@ class FillMaskPipeline(Pipeline):
             task=task,
         )
-        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
+        self.check_model_type(
+            TF_MODEL_FOR_MASKED_LM_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING
+        )
 
         self.top_k = top_k
         self.targets = targets
         if self.tokenizer.mask_token_id is None:
...
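Switching check_model_type to the masked-LM mappings narrows the accepted architectures to genuine masked-LM heads instead of the broader LM-head mapping; the check runs at pipeline construction time. A small usage sketch reusing the tiny checkpoint from the tests below (downloads the model, output is random for an untrained tiny model):

from transformers import pipeline

# Construction triggers check_model_type against the masked-LM mapping.
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="pt")
print(unmasker("Paris is the <mask> of France.")[0]["token_str"])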
@@ -19,7 +19,15 @@ from functools import lru_cache
 from typing import List, Optional
 from unittest import mock, skipIf
 
-from transformers import TOKENIZER_MAPPING, AutoTokenizer, is_tf_available, is_torch_available, pipeline
+from transformers import (
+    FEATURE_EXTRACTOR_MAPPING,
+    TOKENIZER_MAPPING,
+    AutoFeatureExtractor,
+    AutoTokenizer,
+    is_tf_available,
+    is_torch_available,
+    pipeline,
+)
 from transformers.file_utils import to_py_obj
 from transformers.pipelines import Pipeline
 from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
@@ -81,6 +89,16 @@ def get_tiny_tokenizer_from_checkpoint(checkpoint):
     return tokenizer
 
 
+def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config):
+    try:
+        feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
+    except Exception:
+        feature_extractor = None
+    if hasattr(tiny_config, "image_size") and feature_extractor:
+        feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
+    return feature_extractor
+
+
 class ANY:
     def __init__(self, _type):
         self._type = _type
@@ -94,10 +112,14 @@ class ANY:
 class PipelineTestCaseMeta(type):
     def __new__(mcs, name, bases, dct):
-        def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class):
+        def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
             @skipIf(tiny_config is None, "TinyConfig does not exist")
             @skipIf(checkpoint is None, "checkpoint does not exist")
             def test(self):
+                if ModelClass.__name__.endswith("ForCausalLM"):
+                    tiny_config.is_encoder_decoder = False
+                if ModelClass.__name__.endswith("WithLMHead"):
+                    tiny_config.is_decoder = True
                 model = ModelClass(tiny_config)
                 if hasattr(model, "eval"):
                     model = model.eval()
@@ -110,7 +132,8 @@ class PipelineTestCaseMeta(type):
                 # provide some default tokenizer and hope for the best.
                 except:  # noqa: E722
                     self.skipTest(f"Ignoring {ModelClass}, cannot create a simple tokenizer")
-                self.run_pipeline_test(model, tokenizer)
+                feature_extractor = get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config)
+                self.run_pipeline_test(model, tokenizer, feature_extractor)
 
             return test
@@ -125,10 +148,24 @@ class PipelineTestCaseMeta(type):
                 checkpoint = get_checkpoint_from_architecture(model_architecture)
                 tiny_config = get_tiny_config_from_class(configuration)
                 tokenizer_classes = TOKENIZER_MAPPING.get(configuration, [])
+                feature_extractor_class = FEATURE_EXTRACTOR_MAPPING.get(configuration, None)
                 for tokenizer_class in tokenizer_classes:
                     if tokenizer_class is not None and tokenizer_class.__name__.endswith("Fast"):
-                        test_name = f"test_{prefix}_{configuration.__name__}_{model_architecture.__name__}_{tokenizer_class.__name__}"
-                        dct[test_name] = gen_test(model_architecture, checkpoint, tiny_config, tokenizer_class)
+                        tokenizer_name = tokenizer_class.__name__ if tokenizer_class else "notokenizer"
+                        feature_extractor_name = (
+                            feature_extractor_class.__name__
+                            if feature_extractor_class
+                            else "nofeature_extractor"
+                        )
+                        test_name = f"test_{prefix}_{configuration.__name__}_{model_architecture.__name__}_{tokenizer_name}_{feature_extractor_name}"
+                        dct[test_name] = gen_test(
+                            model_architecture,
+                            checkpoint,
+                            tiny_config,
+                            tokenizer_class,
+                            feature_extractor_class,
+                        )
 
         return type.__new__(mcs, name, bases, dct)
...
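With the feature extractor threaded through, generated test names now encode both preprocessor classes, so a missing one shows up as "notokenizer" / "nofeature_extractor" rather than silently dropping out of the name. A sketch of the naming scheme with illustrative class names (not taken from an actual test run):

prefix = "pt"
configuration_name = "BertConfig"               # illustrative
model_architecture_name = "BertForMaskedLM"     # illustrative
tokenizer_name = "BertTokenizerFast"
feature_extractor_name = "nofeature_extractor"  # BERT has no feature extractor

test_name = f"test_{prefix}_{configuration_name}_{model_architecture_name}_{tokenizer_name}_{feature_extractor_name}"
print(test_name)
# test_pt_BertConfig_BertForMaskedLM_BertTokenizerFast_nofeature_extractor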
@@ -15,6 +15,10 @@
 import unittest
 
 from transformers import (
+    MODEL_FOR_CAUSAL_LM_MAPPING,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
+    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
@@ -22,137 +26,81 @@ from transformers import (
     BlenderbotSmallTokenizer,
     Conversation,
     ConversationalPipeline,
-    is_torch_available,
     pipeline,
 )
 from transformers.testing_utils import is_pipeline_test, require_torch, slow, torch_device
 
-from .test_pipelines_common import MonoInputPipelineCommonMixin
-
-
-if is_torch_available():
-    import torch
-    from torch import nn
-
-    from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel
+from .test_pipelines_common import ANY, PipelineTestCaseMeta
 
 
 DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
 
 
 @is_pipeline_test
-class SimpleConversationPipelineTests(unittest.TestCase):
-    def get_pipeline(self):
-        # When
-        config = GPT2Config(
-            vocab_size=263,
-            n_ctx=128,
-            max_length=128,
-            n_embd=64,
-            n_layer=1,
-            n_head=8,
-            bos_token_id=256,
-            eos_token_id=257,
-        )
-        model = GPT2LMHeadModel(config)
-        # Force model output to be L
-        V, D = model.lm_head.weight.shape
-        bias = torch.zeros(V)
-        bias[76] = 1
-        weight = torch.zeros((V, D), requires_grad=True)
-
-        model.lm_head.bias = nn.Parameter(bias)
-        model.lm_head.weight = nn.Parameter(weight)
-
-        # # Created with:
-        # import tempfile
-        # from tokenizers import Tokenizer, models
-        # from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
-        # vocab = [(chr(i), i) for i in range(256)]
-        # tokenizer = Tokenizer(models.Unigram(vocab))
-        # with tempfile.NamedTemporaryFile() as f:
-        #     tokenizer.save(f.name)
-        #     real_tokenizer = PreTrainedTokenizerFast(tokenizer_file=f.name, eos_token="<eos>", bos_token="<bos>")
-        # real_tokenizer._tokenizer.save("dummy.json")
-        # Special tokens are automatically added at load time.
-        tokenizer = AutoTokenizer.from_pretrained("Narsil/small_conversational_test")
-        conversation_agent = pipeline(
-            task="conversational", device=DEFAULT_DEVICE_NUM, model=model, tokenizer=tokenizer
-        )
-        return conversation_agent
-
-    @require_torch
-    def test_integration_torch_conversation(self):
-        conversation_agent = self.get_pipeline()
+class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
+    model_mapping = dict(
+        list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
+        if MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        else [] + list(MODEL_FOR_CAUSAL_LM_MAPPING.items())
+        if MODEL_FOR_CAUSAL_LM_MAPPING
+        else []
+    )
+    tf_model_mapping = dict(
+        list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
+        if TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        else [] + list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.items())
+        if TF_MODEL_FOR_CAUSAL_LM_MAPPING
+        else []
+    )
+
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
+        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
+        # Simple
+        outputs = conversation_agent(Conversation("Hi there!"))
+        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
+
+        # Single list
+        outputs = conversation_agent([Conversation("Hi there!")])
+        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
 
+        # Batch
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         conversation_2 = Conversation("What's the last book you have read?")
 
         self.assertEqual(len(conversation_1.past_user_inputs), 0)
         self.assertEqual(len(conversation_2.past_user_inputs), 0)
 
-        result = conversation_agent([conversation_1, conversation_2], max_length=48)
-
-        # Two conversations in one pass
-        self.assertEqual(result, [conversation_1, conversation_2])
+        outputs = conversation_agent([conversation_1, conversation_2])
+        self.assertEqual(outputs, [conversation_1, conversation_2])
         self.assertEqual(
-            result,
+            outputs,
             [
                 Conversation(
-                    None,
                     past_user_inputs=["Going to the movies tonight - any suggestions?"],
-                    generated_responses=["L"],
-                ),
-                Conversation(
-                    None, past_user_inputs=["What's the last book you have read?"], generated_responses=["L"]
+                    generated_responses=[ANY(str)],
                 ),
+                Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]),
             ],
         )
 
         # One conversation with history
         conversation_2.add_user_input("Why do you recommend it?")
-        result = conversation_agent(conversation_2, max_length=64)
-
-        self.assertEqual(result, conversation_2)
+        outputs = conversation_agent(conversation_2)
+        self.assertEqual(outputs, conversation_2)
         self.assertEqual(
-            result,
+            outputs,
             Conversation(
-                None,
                 past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
-                generated_responses=["L", "L"],
+                generated_responses=[ANY(str), ANY(str)],
             ),
         )
 
+        with self.assertRaises(ValueError):
+            conversation_agent("Hi there!")
 
-class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
-    pipeline_task = "conversational"
-    small_models = []  # Models tested without the @slow decorator
-    large_models = ["microsoft/DialoGPT-medium"]  # Models tested with the @slow decorator
-    invalid_inputs = ["Hi there!", Conversation()]
-
-    def _test_pipeline(
-        self, conversation_agent
-    ):  # override the default test method to check that the output is a `Conversation` object
-        self.assertIsNotNone(conversation_agent)
-
-        # We need to recreate conversation for successive tests to pass as
-        # Conversation objects get *consumed* by the pipeline
-        conversation = Conversation("Hi there!")
-        mono_result = conversation_agent(conversation)
-        self.assertIsInstance(mono_result, Conversation)
-
-        conversations = [Conversation("Hi there!"), Conversation("How are you?")]
-        multi_result = conversation_agent(conversations)
-        self.assertIsInstance(multi_result, list)
-        self.assertIsInstance(multi_result[0], Conversation)
+        with self.assertRaises(ValueError):
+            conversation_agent(Conversation())
+
         # Conversation have been consumed and are not valid anymore
         # Inactive conversations passed to the pipeline raise a ValueError
-        self.assertRaises(ValueError, conversation_agent, conversation)
-        self.assertRaises(ValueError, conversation_agent, conversations)
-
-        for bad_input in self.invalid_inputs:
-            self.assertRaises(Exception, conversation_agent, bad_input)
-        self.assertRaises(Exception, conversation_agent, self.invalid_inputs)
+        with self.assertRaises(ValueError):
+            conversation_agent(conversation_2)
 
     @require_torch
     @slow
...
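The rewritten test compares whole Conversation objects against templates containing ANY(str), so it no longer pins exact generated text the way the old "L"-forcing model did. ANY.__eq__ is not shown in this diff; a minimal stand-in that would make those assertions work compares by isinstance:

class ANY:
    """Minimal stand-in for the ANY helper imported from test_pipelines_common."""

    def __init__(self, _type):
        self._type = _type

    def __eq__(self, other):
        # Assumption: "equal" means "is an instance of the wrapped type".
        return isinstance(other, self._type)


print("whatever the model generated" == ANY(str))  # True
print(["a", "b"] == [ANY(str), ANY(str)])          # True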
@@ -61,13 +61,15 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
             raise ValueError("We expect lists of floats, nothing else")
         return shape
 
-    def run_pipeline_test(self, model, tokenizer):
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
         if isinstance(model.config, LxmertConfig):
             # This is a bimodal model, we need to find a more consistent way
             # to switch on those models.
             return
-        feature_extractor = FeatureExtractionPipeline(model=model, tokenizer=tokenizer)
+        feature_extractor = FeatureExtractionPipeline(
+            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
+        )
         if feature_extractor.model.config.is_encoder_decoder:
             # encoder_decoder models are trickier for this pipeline.
             # Do we want encoder + decoder inputs to get some features?
...
@@ -159,16 +159,16 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="pt")
         unmasker.tokenizer.pad_token_id = None
         unmasker.tokenizer.pad_token = None
-        self.run_pipeline_test(unmasker.model, unmasker.tokenizer)
+        self.run_pipeline_test(unmasker.model, unmasker.tokenizer, None)
 
     @require_tf
     def test_model_no_pad_tf(self):
         unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="tf")
         unmasker.tokenizer.pad_token_id = None
         unmasker.tokenizer.pad_token = None
-        self.run_pipeline_test(unmasker.model, unmasker.tokenizer)
+        self.run_pipeline_test(unmasker.model, unmasker.tokenizer, None)
 
-    def run_pipeline_test(self, model, tokenizer):
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
         if tokenizer.mask_token_id is None:
             self.skipTest("The provided tokenizer has no mask token, (probably reformer)")
...
@@ -72,7 +72,7 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         outputs = text_classifier("Birds are a type of animal")
         self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])
 
-    def run_pipeline_test(self, model, tokenizer):
+    def run_pipeline_test(self, model, tokenizer, feature_extractor):
         text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
 
         # Small inputs because BartTokenizer tiny has maximum position embeddings = 22
...