Unverified Commit 323f28dc authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing `image-segmentation` tests. (#14223)

parent 7396095a
...@@ -126,13 +126,13 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -126,13 +126,13 @@ class ImageSegmentationPipeline(Pipeline):
def _forward(self, model_inputs): def _forward(self, model_inputs):
target_size = model_inputs.pop("target_size") target_size = model_inputs.pop("target_size")
outputs = self.model(**model_inputs) model_outputs = self.model(**model_inputs)
model_outputs = {"outputs": outputs, "target_size": target_size} model_outputs["target_size"] = target_size
return model_outputs return model_outputs
def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5): def postprocess(self, model_outputs, threshold=0.9, mask_threshold=0.5):
raw_annotations = self.feature_extractor.post_process_segmentation( raw_annotations = self.feature_extractor.post_process_segmentation(
model_outputs["outputs"], model_outputs["target_size"], threshold=threshold, mask_threshold=0.5 model_outputs, model_outputs["target_size"], threshold=threshold, mask_threshold=0.5
) )
raw_annotation = raw_annotations[0] raw_annotation = raw_annotations[0]
......
...@@ -51,13 +51,18 @@ else: ...@@ -51,13 +51,18 @@ else:
@require_timm @require_timm
@require_torch @require_torch
@is_pipeline_test @is_pipeline_test
@unittest.skip("Skip while fixing segmentation pipeline tests")
class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING
@require_datasets def get_test_pipeline(self, model, tokenizer, feature_extractor):
def run_pipeline_test(self, model, tokenizer, feature_extractor):
image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor) image_segmenter = ImageSegmentationPipeline(model=model, feature_extractor=feature_extractor)
return image_segmenter, [
"./tests/fixtures/tests_samples/COCO/000000039769.png",
"./tests/fixtures/tests_samples/COCO/000000039769.png",
]
@require_datasets
def run_pipeline_test(self, image_segmenter, examples):
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0) outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12) self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment