"examples/vscode:/vscode.git/clone" did not exist on "c54419658b0459580573703bd0ab945b6beb1b5a"
Unverified commit 0befb513 authored by Kevin Canwen Xu, committed by GitHub

Pipeline model type check (#5679)

* Add model type check for pipelines

* Add model type check for pipelines

* rename func

* Fix the init parameters

* Fix format

* rollback unnecessary refactor
parent dc31a72f
@@ -46,6 +46,10 @@ if is_tf_available():
         TFAutoModelForQuestionAnswering,
         TFAutoModelForTokenClassification,
         TFAutoModelWithLMHead,
+        TF_MODEL_WITH_LM_HEAD_MAPPING,
+        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     )

 if is_torch_available():
@@ -57,6 +61,11 @@ if is_torch_available():
         AutoModelForTokenClassification,
         AutoModelWithLMHead,
         AutoModelForSeq2SeqLM,
+        MODEL_WITH_LM_HEAD_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     )

 if TYPE_CHECKING:
@@ -396,6 +405,7 @@ class Pipeline(_ScikitCompat):
         if framework is None:
             framework = get_framework()

+        self.task = task
         self.model = model
         self.tokenizer = tokenizer
         self.modelcard = modelcard
@@ -469,6 +479,19 @@ class Pipeline(_ScikitCompat):
         """
         return {name: tensor.to(self.device) for name, tensor in inputs.items()}

+    def check_model_type(self, supported_models):
+        """
+        Check if the model class is in the supported class list of the pipeline.
+        """
+        if not isinstance(supported_models, list):  # Create from a model mapping
+            supported_models = [item[1].__name__ for item in supported_models.items()]
+        if self.model.__class__.__name__ not in supported_models:
+            raise PipelineException(
+                self.task,
+                self.model.base_model_prefix,
+                f"The model '{self.model.__class__.__name__}' is not supported for {self.task}. Supported models are {supported_models}",
+            )
+
     def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
         """
         Parse arguments and tokenize
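Taken on its own, the new helper accepts either a plain list of class names or one of the auto-model mappings (a dict-like object mapping config class to model class) and raises a PipelineException when the loaded model's class is not among them. A self-contained sketch of that logic with stand-in classes (the real mappings and exception live in transformers; every name below is illustrative only):

# Minimal sketch of the check_model_type logic with placeholder classes.
class PipelineException(Exception):
    def __init__(self, task, model_id, reason):
        super().__init__(reason)
        self.task = task
        self.model_id = model_id

class BertConfig: pass
class BertForMaskedLM: pass
class BertModel: pass

# Auto-model mappings map a config class to a task-specific model class.
MODEL_WITH_LM_HEAD_MAPPING = {BertConfig: BertForMaskedLM}

def check_model_type(model, task, supported_models):
    if not isinstance(supported_models, list):  # a mapping: keep the model class names
        supported_models = [item[1].__name__ for item in supported_models.items()]
    if model.__class__.__name__ not in supported_models:
        raise PipelineException(
            task, "bert", f"The model '{model.__class__.__name__}' is not supported for {task}."
        )

check_model_type(BertForMaskedLM(), "fill-mask", MODEL_WITH_LM_HEAD_MAPPING)  # passes
try:
    check_model_type(BertModel(), "fill-mask", MODEL_WITH_LM_HEAD_MAPPING)
except PipelineException as e:
    print(e)  # raised: the bare encoder has no LM head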
@@ -615,6 +638,11 @@ class TextGenerationPipeline(Pipeline):
         "TFCTRLLMHeadModel",
     ]

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(self.ALLOWED_MODELS)
+
     # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
     def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
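For text generation the supported set stays an explicit ALLOWED_MODELS list of class names, and moving the check into __init__ means an unsupported model now fails when the pipeline is built rather than on the first call. A minimal usage sketch (checkpoint name illustrative):

from transformers import pipeline

# GPT2LMHeadModel is in ALLOWED_MODELS, so construction passes the new check;
# an unsupported architecture would raise PipelineException here instead of
# a NotImplementedError inside __call__ as before this commit.
generator = pipeline("text-generation", model="gpt2")
print(generator("Hello, world", max_length=20))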
@@ -640,12 +668,6 @@ class TextGenerationPipeline(Pipeline):
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):
-        if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
-            raise NotImplementedError(
-                "Generation is currently not supported for {}. Please select a model from {} for generation.".format(
-                    self.model.__class__.__name__, self.ALLOWED_MODELS
-                )
-            )
         text_inputs = self._args_parser(*args)
@@ -771,6 +793,12 @@ class TextClassificationPipeline(Pipeline):
     def __init__(self, return_all_scores: bool = False, **kwargs):
         super().__init__(**kwargs)

+        self.check_model_type(
+            TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+            if self.framework == "tf"
+            else MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
+        )
+
         self.return_all_scores = return_all_scores

     def __call__(self, *args, **kwargs):
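Task pipelines that accept whole model families validate against the framework-specific auto-model mapping instead of a hand-written list; the same tf/pt conditional recurs below for fill-mask, token classification, question answering, summarization, and translation. Since the mappings are keyed by config class, the accepted class names can be listed directly. A quick inspection sketch, assuming the mappings are re-exported at the package top level as in recent releases:

from transformers import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING

# Each entry maps a config class to its task-specific model class;
# check_model_type compares the loaded model against these class names.
supported = [model_class.__name__ for model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values()]
print(supported)  # e.g. ['BertForSequenceClassification', 'XLNetForSequenceClassification', ...]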
@@ -847,6 +875,8 @@ class FillMaskPipeline(Pipeline):
             task=task,
         )

+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+
         self.topk = topk

     def ensure_exactly_one_mask_token(self, masked_index: np.ndarray):
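Because fill-mask validates against the LM-head mapping, any masked-LM checkpoint passes while a model without an LM head is rejected at construction. Illustrative usage (checkpoint name as an example):

from transformers import pipeline

# distilroberta-base loads as a masked LM, whose class is in MODEL_WITH_LM_HEAD_MAPPING.
fill_mask = pipeline("fill-mask", model="distilroberta-base")
print(fill_mask("The capital of France is <mask>."))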
@@ -980,6 +1010,12 @@ class TokenClassificationPipeline(Pipeline):
             task=task,
         )

+        self.check_model_type(
+            TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+            if self.framework == "tf"
+            else MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
+        )
+
         self._basic_tokenizer = BasicTokenizer(do_lower_case=False)
         self.ignore_labels = ignore_labels
         self.grouped_entities = grouped_entities
@@ -1220,6 +1256,10 @@ class QuestionAnsweringPipeline(Pipeline):
             **kwargs,
         )

+        self.check_model_type(
+            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING if self.framework == "tf" else MODEL_FOR_QUESTION_ANSWERING_MAPPING
+        )
+
     @staticmethod
     def create_sample(
         question: Union[str, List[str]], context: Union[str, List[str]]
@@ -1483,9 +1523,13 @@ class SummarizationPipeline(Pipeline):
     on the associated CUDA device id.
     """

-    def __init__(self, **kwargs):
+    def __init__(self, *args, **kwargs):
         kwargs.update(task="summarization")
-        super().__init__(**kwargs)
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(
+            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        )

     def __call__(
         self, *documents, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
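Note the asymmetry here: on PyTorch, summarization now checks the dedicated seq2seq mapping (BART, T5, and friends), while the TF branch falls back to the broader LM-head mapping. A minimal usage sketch with an illustrative seq2seq checkpoint:

from transformers import pipeline

# sshleifer/distilbart-cnn-12-6 is an example checkpoint whose class
# (BartForConditionalGeneration) is in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
print(summarizer("Long article text goes here ...", max_length=60, min_length=10))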
@@ -1615,6 +1659,11 @@ class TranslationPipeline(Pipeline):
     on the associated CUDA device id.
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
     ):