Unverified Commit 934ce3b8 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix MNIST byte flipping (#7081)

* fix MNIST byte flipping

* add test

* move to utils

* remove lazy import
parent 372f4fae
......@@ -7,7 +7,9 @@ import tarfile
import zipfile
import pytest
import torch
import torchvision.datasets.utils as utils
from common_utils import assert_equal
from torch._utils_internal import get_file_path_2
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS
......@@ -215,6 +217,24 @@ class TestDatasetsUtils:
pytest.raises(ValueError, utils.verify_str_arg, 0, ("a",), "arg")
pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg")
@pytest.mark.parametrize(
("dtype", "actual_hex", "expected_hex"),
[
(torch.uint8, "01 23 45 67 89 AB CD EF", "01 23 45 67 89 AB CD EF"),
(torch.float16, "01 23 45 67 89 AB CD EF", "23 01 67 45 AB 89 EF CD"),
(torch.int32, "01 23 45 67 89 AB CD EF", "67 45 23 01 EF CD AB 89"),
(torch.float64, "01 23 45 67 89 AB CD EF", "EF CD AB 89 67 45 23 01"),
],
)
def test_flip_byte_order(self, dtype, actual_hex, expected_hex):
def to_tensor(hex):
return torch.frombuffer(bytes.fromhex(hex), dtype=dtype)
assert_equal(
utils._flip_byte_order(to_tensor(actual_hex)),
to_tensor(expected_hex),
)
@pytest.mark.parametrize(
("kwargs", "expected_error_msg"),
......
......@@ -12,7 +12,7 @@ import numpy as np
import torch
from PIL import Image
from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset
......@@ -519,13 +519,12 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
num_bytes_per_value = torch.iinfo(torch_type).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
if needs_byte_reversal:
parsed = parsed.flip(0)
# The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
# that is little endian and the dtype has more than one byte, we need to flip them.
if sys.byteorder == "little" and parsed.element_size() > 1:
parsed = _flip_byte_order(parsed)
assert parsed.shape[0] == np.prod(s) or not strict
return parsed.view(*s)
......
......@@ -520,3 +520,9 @@ def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
data = np.flip(data, axis=1) # flip on h dimension
data = data[:slice_channels, :, :]
return data.astype(np.float32)
def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
return (
t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment