"test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "87ed70cd5fc0e88551899d07a2614def1964a84a"
Unverified Commit 891704b3 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Removing warning of model type for `microsoft/tapex-base-finetuned-wtq` (#18711)

and friends.
parent 84beb8a4
...@@ -16,14 +16,20 @@ from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Dataset, Pipeline, Pipeli ...@@ -16,14 +16,20 @@ from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Dataset, Pipeline, Pipeli
if is_torch_available(): if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING from ..models.auto.modeling_auto import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
)
if is_tf_available() and is_tensorflow_probability_available(): if is_tf_available() and is_tensorflow_probability_available():
import tensorflow as tf import tensorflow as tf
import tensorflow_probability as tfp import tensorflow_probability as tfp
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
)
class TableQuestionAnsweringArgumentHandler(ArgumentHandler): class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
...@@ -100,9 +106,14 @@ class TableQuestionAnsweringPipeline(Pipeline): ...@@ -100,9 +106,14 @@ class TableQuestionAnsweringPipeline(Pipeline):
self._args_parser = args_parser self._args_parser = args_parser
self.check_model_type( self.check_model_type(
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING dict(
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items()
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
if self.framework == "tf" if self.framework == "tf"
else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING else dict(
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items() + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
) )
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool( self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment