Unverified Commit 8dcb5b81 authored by Philip Meier, committed by GitHub

add prototype utilities to read arbitrary numeric binary files (#4882)

* add FloReader datapipe

* add NumericBinaryReader

* revert unrelated change

* cleanup

* cleanup

* add comment for byte reversal

* use numpy after all

* appease mypy

* use .astype() with copy=False

* add docstring and cleanup

* reuse current _read_flo and revert MNIST changes

* cleanup

* revert demonstration

* refactor

* cleanup

* add support for mutable memory

* add test

* add comments

* catch more exceptions

* fix mypy

* fix variable names

* hardcode flow sizes in test

* add fix dtype docstring

* expand comment on different reading modes

* add comment about files in update mode

* add tests for fromfile

* cleanup

* cleanup
parent fa1aa52d
import sys
import numpy as np
import pytest
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile

@pytest.mark.filterwarnings("error:The given NumPy array is not writeable:UserWarning")
@pytest.mark.parametrize(
    ("np_dtype", "torch_dtype", "byte_order"),
    [
        (">f4", torch.float32, "big"),
        ("<f8", torch.float64, "little"),
        ("<i4", torch.int32, "little"),
        (">i8", torch.int64, "big"),
        ("|u1", torch.uint8, sys.byteorder),
    ],
)
@pytest.mark.parametrize("count", (-1, 2))
@pytest.mark.parametrize("mode", ("rb", "r+b"))
def test_fromfile(tmpdir, np_dtype, torch_dtype, byte_order, count, mode):
    path = tmpdir / "data.bin"
    rng = np.random.RandomState(0)
    rng.randn(5 if count == -1 else count + 1).astype(np_dtype).tofile(path)

    for count_ in (-1, count // 2):
        expected = torch.from_numpy(np.fromfile(path, dtype=np_dtype, count=count_).astype(np_dtype[1:]))

        with open(path, mode) as file:
            actual = fromfile(file, dtype=torch_dtype, byte_order=byte_order, count=count_)

        torch.testing.assert_close(actual, expected)


def test_read_flo(tmpdir):
    path = tmpdir / "test.flo"
    make_fake_flo_file(3, 4, path)

    with open(path, "rb") as file:
        actual = read_flo(file)

    expected = torch.from_numpy(read_flo_ref(path).astype("f4", copy=False))

    torch.testing.assert_close(actual, expected)

import abc
import codecs
import functools
import io
import operator
import pathlib
import string
import sys
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, BinaryIO
import torch
from torchdata.datapipes.iter import (
@@ -30,6 +28,7 @@ from torchvision.prototype.datasets.utils._internal import (
    image_buffer_from_array,
    Decompressor,
    INFINITE_BUFFER_SIZE,
+    fromfile,
)
from torchvision.prototype.features import Image, Label
@@ -50,50 +49,33 @@ class MNISTFileReader(IterDataPipe[torch.Tensor]):
    }
    def __init__(
-        self, datapipe: IterDataPipe[Tuple[Any, io.IOBase]], *, start: Optional[int], stop: Optional[int]
+        self, datapipe: IterDataPipe[Tuple[Any, BinaryIO]], *, start: Optional[int], stop: Optional[int]
    ) -> None:
        self.datapipe = datapipe
        self.start = start
        self.stop = stop

-    @staticmethod
-    def _decode(input: bytes) -> int:
-        return int(codecs.encode(input, "hex"), 16)
-
-    @staticmethod
-    def _to_tensor(chunk: bytes, *, dtype: torch.dtype, shape: List[int], reverse_bytes: bool) -> torch.Tensor:
-        # As is, the chunk is not writeable, because it is read from a file and not from memory. Thus, we copy here to
-        # avoid the warning that torch.frombuffer would emit otherwise. This also enables inplace operations on the
-        # contents, which would otherwise fail.
-        chunk = bytearray(chunk)
-        if reverse_bytes:
-            chunk.reverse()
-            tensor = torch.frombuffer(chunk, dtype=dtype).flip(0)
-        else:
-            tensor = torch.frombuffer(chunk, dtype=dtype)
-        return tensor.reshape(shape)
-
    def __iter__(self) -> Iterator[torch.Tensor]:
        for _, file in self.datapipe:
-            magic = self._decode(file.read(4))
+            read = functools.partial(fromfile, file, byte_order="big")
+
+            magic = int(read(dtype=torch.int32, count=1))
            dtype = self._DTYPE_MAP[magic // 256]
            ndim = magic % 256 - 1

-            num_samples = self._decode(file.read(4))
-            shape = [self._decode(file.read(4)) for _ in range(ndim)]
-
-            num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
-            # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
-            # we need to reverse the bytes before we can read them with torch.frombuffer().
-            reverse_bytes = sys.byteorder == "little" and num_bytes_per_value > 1
-            chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
+            num_samples = int(read(dtype=torch.int32, count=1))
+            shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
+            count = prod(shape) if shape else 1

            start = self.start or 0
            stop = min(self.stop, num_samples) if self.stop else num_samples

-            file.seek(start * chunk_size, 1)
+            if start:
+                num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+                file.seek(num_bytes_per_value * count * start, 1)

            for _ in range(stop - start):
-                yield self._to_tensor(file.read(chunk_size), dtype=dtype, shape=shape, reverse_bytes=reverse_bytes)
+                yield read(dtype=dtype, count=count).reshape(shape)
class _MNISTBase(Dataset):
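For context on the header parsing in MNISTFileReader above: the MNIST/IDX format stores a 4-byte big-endian magic value whose third byte encodes the data type and whose fourth byte encodes the total number of dimensions, so magic // 256 yields the dtype code and magic % 256 - 1 the number of per-sample dimensions (the leading dimension, the sample count, is read separately). A minimal standalone sketch with illustrative values, not part of the commit:

import struct

# Hypothetical IDX-style header: dtype code 0x08 (unsigned byte) and 3 dimensions in
# total, i.e. 2 samples of 3x4 values -> magic 0x00000803, followed by 2, 3, 4 as int32.
header = struct.pack(">iiii", 0x00000803, 2, 3, 4)

magic = int.from_bytes(header[:4], byteorder="big")
dtype_code = magic // 256  # 0x08 -> unsigned byte
ndim = magic % 256 - 1     # 3 total dimensions minus the sample dimension -> 2

num_samples = int.from_bytes(header[4:8], byteorder="big")  # 2
shape = [int.from_bytes(header[8 + 4 * i : 12 + 4 * i], byteorder="big") for i in range(ndim)]  # [3, 4]

assert (dtype_code, ndim, num_samples, shape) == (0x08, 2, 2, [3, 4])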
@@ -3,10 +3,12 @@ import functools
import gzip
import io
import lzma
+import mmap
import os
import os.path
import pathlib
import pickle
+from typing import BinaryIO
from typing import (
    Sequence,
    Callable,
@@ -24,6 +26,7 @@ from typing import cast
import numpy as np
import PIL.Image
import torch
import torch.distributed as dist
import torch.utils.data
from torch.utils.data import IterDataPipe
@@ -43,6 +46,8 @@ __all__ = [
    "path_accessor",
    "path_comparator",
    "Decompressor",
+    "fromfile",
+    "read_flo",
]
K = TypeVar("K")
@@ -253,3 +258,66 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
    # dp = dp.cycle(2)
    dp = TakerDataPipe(dp, dataset_size)
    return dp

def fromfile(
    file: BinaryIO,
    *,
    dtype: torch.dtype,
    byte_order: str,
    count: int = -1,
) -> torch.Tensor:
    """Construct a tensor from a binary file.

    .. note::

        This function is similar to :func:`numpy.fromfile` with two notable differences:

        1. This function only accepts an open binary file, but not a path to it.
        2. This function has an additional ``byte_order`` parameter, since PyTorch's ``dtype``'s do not support that
           concept.

    .. note::

        If the ``file`` was opened in update mode, i.e. "r+b" or "w+b", reading data is much faster. Be aware that as
        long as the file is still open, inplace operations on the returned tensor will reflect back to the file.

    Args:
        file (IO): Open binary file.
        dtype (torch.dtype): Data type of the underlying data as well as of the returned tensor.
        byte_order (str): Byte order of the data. Can be "little" or "big" endian.
        count (int): Number of values of the returned tensor. If ``-1`` (default), will read the complete file.
    """
    byte_order = "<" if byte_order == "little" else ">"
    char = "f" if dtype.is_floating_point else ("i" if dtype.is_signed else "u")
    item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
    np_dtype = byte_order + char + str(item_size)

    # PyTorch does not support tensors with underlying read-only memory. In case
    # - the file has a .fileno(),
    # - the file was opened for updating, i.e. 'r+b' or 'w+b',
    # - the file is seekable
    # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to
    # a mutable location afterwards.
    buffer: Union[memoryview, bytearray]
    try:
        buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
        # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
        file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
    except (PermissionError, io.UnsupportedOperation):
        # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
        buffer = bytearray(file.read(-1 if count == -1 else count * item_size))

    # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
    # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
    # successive .astype() call.
    return torch.from_numpy(np.frombuffer(buffer, dtype=np_dtype, count=count).astype(np_dtype[1:], copy=False))


def read_flo(file: BinaryIO) -> torch.Tensor:
    if file.read(4) != b"PIEH":
        raise ValueError("Magic number incorrect. Invalid .flo file")

    width, height = fromfile(file, dtype=torch.int32, byte_order="little", count=2)
    flow = fromfile(file, dtype=torch.float32, byte_order="little", count=height * width * 2)
    return flow.reshape((height, width, 2)).permute((2, 0, 1))
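A short usage sketch of the two helpers added above, mirroring the tests at the top of this page (file names and values are illustrative, not part of the commit):

import numpy as np
import torch
from torchvision.prototype.datasets.utils._internal import fromfile, read_flo

# Write six big-endian float32 values to a scratch file.
np.arange(6, dtype=">f4").tofile("data.bin")

# Opening the file in update mode ("r+b") lets fromfile take the mmap-based path and avoid a copy;
# a plain "rb" handle falls back to read() into a mutable bytearray.
with open("data.bin", "r+b") as file:
    values = fromfile(file, dtype=torch.float32, byte_order="big", count=-1)
    assert values.dtype == torch.float32 and values.tolist() == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# .flo files store little-endian data; read_flo returns the flow as a 2 x height x width tensor.
with open("flow.flo", "rb") as file:  # assumes an existing .flo file at this path
    flow = read_flo(file)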