Unverified Commit 6e5a83fb authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Fixed read_image bug (#3948)

* Fixed read_image bug

* Removed unused import

* Skip tests if cv2 unavailable

* Removed cv2 dependency
parent 96ad7f0c
......@@ -258,6 +258,42 @@ def test_write_file_non_ascii():
assert content == saved_content
@pytest.mark.parametrize('shape', [
(27, 27),
(60, 60),
(105, 105),
])
def test_read_1_bit_png(shape):
with get_tmp_dir() as root:
image_path = os.path.join(root, f'test_{shape}.png')
pixels = np.random.rand(*shape) > 0.5
img = Image.fromarray(pixels)
img.save(image_path)
img1 = read_image(image_path)
img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
assert_equal(img1, img2, check_stride=False)
@pytest.mark.parametrize('shape', [
(27, 27),
(60, 60),
(105, 105),
])
@pytest.mark.parametrize('mode', [
ImageReadMode.UNCHANGED,
ImageReadMode.GRAY,
])
def test_read_1_bit_png_consistency(shape, mode):
with get_tmp_dir() as root:
image_path = os.path.join(root, f'test_{shape}.png')
pixels = np.random.rand(*shape) > 0.5
img = Image.fromarray(pixels)
img.save(image_path)
img1 = read_image(image_path, mode)
img2 = read_image(image_path, mode)
assert_equal(img1, img2)
@needs_cuda
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
......
......@@ -73,6 +73,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
int channels = png_get_channels(png_ptr, info_ptr);
if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
png_set_expand_gray_1_2_4_to_8(png_ptr);
if (mode != IMAGE_READ_MODE_UNCHANGED) {
// TODO: consider supporting PNG_INFO_tRNS
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
......@@ -155,10 +158,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr);
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += bytes;
ptr += width * channels;
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment