Unverified Commit d066c373 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding support for LayoutLMvX variants for `object-detection`. (#20143)

* Adding support for LayoutLMvX variants for `object-detection`.

* Revert bogus `layoutlm` feature extractor, which does not exist (it was a v2 model).

* Updated condition.

* Addressed review comments.
parent 7ec1dc88
...@@ -345,7 +345,7 @@ SUPPORTED_TASKS = { ...@@ -345,7 +345,7 @@ SUPPORTED_TASKS = {
"tf": (), "tf": (),
"pt": (AutoModelForObjectDetection,) if is_torch_available() else (), "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
"default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}}, "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
"type": "image", "type": "multimodal",
}, },
"zero-shot-object-detection": { "zero-shot-object-detection": {
"impl": ZeroShotObjectDetectionPipeline, "impl": ZeroShotObjectDetectionPipeline,
......
...@@ -11,7 +11,7 @@ if is_vision_available(): ...@@ -11,7 +11,7 @@ if is_vision_available():
if is_torch_available(): if is_torch_available():
import torch import torch
from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING from ..models.auto.modeling_auto import MODEL_FOR_OBJECT_DETECTION_MAPPING, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -39,7 +39,9 @@ class ObjectDetectionPipeline(Pipeline): ...@@ -39,7 +39,9 @@ class ObjectDetectionPipeline(Pipeline):
raise ValueError(f"The {self.__class__} is only available in PyTorch.") raise ValueError(f"The {self.__class__} is only available in PyTorch.")
requires_backends(self, "vision") requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_OBJECT_DETECTION_MAPPING) self.check_model_type(
dict(MODEL_FOR_OBJECT_DETECTION_MAPPING.items() + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items())
)
def _sanitize_parameters(self, **kwargs): def _sanitize_parameters(self, **kwargs):
postprocess_kwargs = {} postprocess_kwargs = {}
...@@ -82,6 +84,8 @@ class ObjectDetectionPipeline(Pipeline): ...@@ -82,6 +84,8 @@ class ObjectDetectionPipeline(Pipeline):
image = load_image(image) image = load_image(image)
target_size = torch.IntTensor([[image.height, image.width]]) target_size = torch.IntTensor([[image.height, image.width]])
inputs = self.feature_extractor(images=[image], return_tensors="pt") inputs = self.feature_extractor(images=[image], return_tensors="pt")
if self.tokenizer is not None:
inputs = self.tokenizer(text=inputs["words"], boxes=inputs["boxes"], return_tensors="pt")
inputs["target_size"] = target_size inputs["target_size"] = target_size
return inputs return inputs
...@@ -89,11 +93,39 @@ class ObjectDetectionPipeline(Pipeline): ...@@ -89,11 +93,39 @@ class ObjectDetectionPipeline(Pipeline):
target_size = model_inputs.pop("target_size") target_size = model_inputs.pop("target_size")
outputs = self.model(**model_inputs) outputs = self.model(**model_inputs)
model_outputs = outputs.__class__({"target_size": target_size, **outputs}) model_outputs = outputs.__class__({"target_size": target_size, **outputs})
if self.tokenizer is not None:
model_outputs["bbox"] = model_inputs["bbox"]
return model_outputs return model_outputs
def postprocess(self, model_outputs, threshold=0.9): def postprocess(self, model_outputs, threshold=0.9):
target_size = model_outputs["target_size"] target_size = model_outputs["target_size"]
raw_annotations = self.feature_extractor.post_process_object_detection(model_outputs, threshold, target_size) if self.tokenizer is not None:
# This is a LayoutLMForTokenClassification variant.
# The OCR got the boxes and the model classified the words.
width, height = target_size[0].tolist()
def unnormalize(bbox):
return self._get_bounding_box(
torch.Tensor(
[
(width * bbox[0] / 1000),
(height * bbox[1] / 1000),
(width * bbox[2] / 1000),
(height * bbox[3] / 1000),
]
)
)
scores, classes = model_outputs["logits"].squeeze(0).softmax(dim=-1).max(dim=-1)
labels = [self.model.config.id2label[prediction] for prediction in classes.tolist()]
boxes = [unnormalize(bbox) for bbox in model_outputs["bbox"].squeeze(0)]
keys = ["score", "label", "box"]
annotation = [dict(zip(keys, vals)) for vals in zip(scores.tolist(), labels, boxes) if vals[0] > threshold]
else:
# This is a regular ForObjectDetectionModel
raw_annotations = self.feature_extractor.post_process_object_detection(
model_outputs, threshold, target_size
)
raw_annotation = raw_annotations[0] raw_annotation = raw_annotations[0]
scores = raw_annotation["scores"] scores = raw_annotation["scores"]
labels = raw_annotation["labels"] labels = raw_annotation["labels"]
......
...@@ -243,3 +243,30 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase ...@@ -243,3 +243,30 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
{"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}}, {"score": 0.9987, "label": "cat", "box": {"xmin": 345, "ymin": 23, "xmax": 640, "ymax": 368}},
], ],
) )
@require_torch
@slow
def test_layoutlm(self):
    """Smoke-test the `object-detection` pipeline with a LayoutLM token-classification model.

    LayoutLM variants are routed through the object-detection pipeline: the
    feature extractor's OCR supplies the word bounding boxes and the model
    classifies each word, so each "detection" is a labelled OCR box rather
    than a regressed box. Requires network access (downloads the model and
    the test image), hence `@slow`.
    """
    model_id = "philschmid/layoutlm-funsd"
    # Very high threshold so only the two most confident word predictions survive,
    # keeping the expected-output list short and stable.
    threshold = 0.998
    object_detector = pipeline("object-detection", model=model_id, threshold=threshold)
    outputs = object_detector(
        "https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png"
    )
    self.assertEqual(
        # Round to 4 decimals so minor numeric drift does not break the comparison.
        nested_simplify(outputs, decimals=4),
        [
            {
                "score": 0.9982,
                "label": "B-QUESTION",
                # NOTE(review): ymax == xmax in both expected boxes (719/719, 735/735),
                # which looks like a width/height mix-up somewhere in the box
                # un-normalization — confirm the pipeline's coordinate handling.
                "box": {"xmin": 654, "ymin": 165, "xmax": 719, "ymax": 719},
            },
            {
                "score": 0.9982,
                "label": "I-QUESTION",
                "box": {"xmin": 691, "ymin": 202, "xmax": 735, "ymax": 735},
            },
        ],
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment