"docs/source/zh/run_scripts.md" did not exist on "eb849f6604c7dcc0e96d68f4851e52e253b9f0e5"
Unverified Commit 891704b3 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Removing warning of model type for `microsoft/tapex-base-finetuned-wtq` (#18711)

and friends.
parent 84beb8a4
......@@ -16,14 +16,20 @@ from .base import PIPELINE_INIT_ARGS, ArgumentHandler, Dataset, Pipeline, Pipeli
if is_torch_available():
import torch
from ..models.auto.modeling_auto import MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_auto import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
)
if is_tf_available() and is_tensorflow_probability_available():
import tensorflow as tf
import tensorflow_probability as tfp
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
from ..models.auto.modeling_tf_auto import (
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
)
class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
......@@ -100,9 +106,14 @@ class TableQuestionAnsweringPipeline(Pipeline):
self._args_parser = args_parser
self.check_model_type(
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
dict(
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items()
+ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
if self.framework == "tf"
else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
else dict(
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.items() + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items()
)
)
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment