Unverified Commit 216b2f9e authored by Ankur Goyal's avatar Ankur Goyal Committed by GitHub
Browse files

Move the model type check (#19027)


Co-authored-by: default avatarAnkur Goyal <ankur@impira.com>
parent ea75e9f1
...@@ -116,16 +116,17 @@ class DocumentQuestionAnsweringPipeline(Pipeline): ...@@ -116,16 +116,17 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig": if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
self.model_type = ModelType.VisionEncoderDecoder self.model_type = ModelType.VisionEncoderDecoder
if self.model.config.encoder.model_type != "donut-swin": if self.model.config.encoder.model_type != "donut-swin":
raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut") raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut")
elif self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM
else: else:
self.model_type = ModelType.LayoutLMv2andv3 self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING)
if self.model.config.__class__.__name__ == "LayoutLMConfig":
self.model_type = ModelType.LayoutLM
else:
self.model_type = ModelType.LayoutLMv2andv3
def _sanitize_parameters( def _sanitize_parameters(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment