"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1124d95dbb1a3512d3e80791d73d0f541d1d7e9f"
Unverified Commit 83a2e694 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Cast masks to np.unit8 before converting to PIL.Image.Image (#19616)

* Cast masks to np.unit8 before converting to PIL.Image.Image

* Update tests

* Fixup
parent 909f0709
...@@ -172,7 +172,7 @@ class ImageSegmentationPipeline(Pipeline): ...@@ -172,7 +172,7 @@ class ImageSegmentationPipeline(Pipeline):
for label in labels: for label in labels:
mask = (segmentation == label) * 255 mask = (segmentation == label) * 255
mask = Image.fromarray(mask, mode="L") mask = Image.fromarray(mask.astype(np.uint8), mode="L")
label = self.model.config.id2label[label] label = self.model.config.id2label[label]
annotation.append({"score": None, "label": label, "mask": mask}) annotation.append({"score": None, "label": label, "mask": mask})
else: else:
......
...@@ -226,15 +226,11 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa ...@@ -226,15 +226,11 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
self.assertEqual( self.assertEqual(
nested_simplify(outputs, decimals=4), nested_simplify(outputs, decimals=4),
[ [
{ {"score": None, "label": "LABEL_0", "mask": "42d09072282a32da2ac77375a4c1280f"},
"score": None,
"label": "LABEL_0",
"mask": "775518a7ed09eea888752176c6ba8f38",
},
{ {
"score": None, "score": None,
"label": "LABEL_1", "label": "LABEL_1",
"mask": "a12da23a46848128af68c63aa8ba7a02", "mask": "46b8cc3976732873b219f77a1213c1a5",
}, },
], ],
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment