Unverified Commit 85b78580 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

decode_* returns contiguous tensors (#4898)

Description:
- Applied `contiguous` on decoded output tensor in decode_jpeg and decode_png ops
- Updated tests and docs

Related to #4880
parent 031e129b
...@@ -94,6 +94,8 @@ def test_decode_jpeg(img_path, pil_mode, mode): ...@@ -94,6 +94,8 @@ def test_decode_jpeg(img_path, pil_mode, mode):
data = read_file(img_path) data = read_file(img_path)
img_ljpeg = decode_image(data, mode=mode) img_ljpeg = decode_image(data, mode=mode)
assert img_ljpeg.is_contiguous()
# Permit a small variation on pixel values to account for implementation # Permit a small variation on pixel values to account for implementation
# differences between Pillow and LibJPEG. # differences between Pillow and LibJPEG.
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item() abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
...@@ -173,6 +175,8 @@ def test_decode_png(img_path, pil_mode, mode): ...@@ -173,6 +175,8 @@ def test_decode_png(img_path, pil_mode, mode):
data = read_file(img_path) data = read_file(img_path)
img_lpng = decode_image(data, mode=mode) img_lpng = decode_image(data, mode=mode)
assert img_lpng.is_contiguous()
tol = 0 if pil_mode is None else 1 tol = 0 if pil_mode is None else 1
if PILLOW_VERSION >= (8, 3) and pil_mode == "LA": if PILLOW_VERSION >= (8, 3) and pil_mode == "LA":
......
...@@ -148,7 +148,7 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) { ...@@ -148,7 +148,7 @@ torch::Tensor decode_jpeg(const torch::Tensor& data, ImageReadMode mode) {
jpeg_finish_decompress(&cinfo); jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo); jpeg_destroy_decompress(&cinfo);
return tensor.permute({2, 0, 1}); return tensor.permute({2, 0, 1}).contiguous();
} }
#endif #endif
......
...@@ -224,7 +224,7 @@ torch::Tensor decode_png( ...@@ -224,7 +224,7 @@ torch::Tensor decode_png(
} }
} }
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1}); return tensor.permute({2, 0, 1}).contiguous();
} }
#endif #endif
......
...@@ -59,7 +59,7 @@ def write_file(filename: str, data: torch.Tensor) -> None: ...@@ -59,7 +59,7 @@ def write_file(filename: str, data: torch.Tensor) -> None:
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
""" """
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. Decodes a PNG image into a 3 dimensional RGB or grayscale contiguous Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255]. The values of the output tensor are uint8 in [0, 255].
...@@ -117,7 +117,7 @@ def decode_jpeg( ...@@ -117,7 +117,7 @@ def decode_jpeg(
input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu" input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor. Decodes a JPEG image into a 3 dimensional RGB or grayscale contiguous Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255. The values of the output tensor are uint8 between 0 and 255.
...@@ -185,7 +185,7 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): ...@@ -185,7 +185,7 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
""" """
Detects whether an image is a JPEG or PNG and performs the appropriate Detects whether an image is a JPEG or PNG and performs the appropriate
operation to decode the image into a 3 dimensional RGB or grayscale Tensor. operation to decode the image into a 3 dimensional RGB or grayscale contiguous Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255]. The values of the output tensor are uint8 in [0, 255].
...@@ -207,7 +207,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN ...@@ -207,7 +207,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
""" """
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale contiguous Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255]. The values of the output tensor are uint8 in [0, 255].
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment