Unverified Commit 29aa38ae authored by Philip Meier, committed by GitHub

replace np.frombuffer with torch.frombuffer in MNIST prototype (#4651)

* replace np.frombuffer with torch.frombuffer in MNIST prototype

* cleanup

* appease mypy

* more cleanup

* clarify inplace offset

* fix num bytes for floating point data
parent 979ecaca
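
The core of the change is reading the big-endian IDX payload with torch.frombuffer(), which only understands the machine's native byte order. The following standalone sketch is not the torchvision code itself; the helper name read_big_endian and its signature are made up for illustration. It mirrors the approach taken in the diff below: derive the element size from the torch dtype, reverse the raw bytes on little-endian hosts, and undo the resulting element reversal with flip(0).

```python
import sys
from typing import List

import torch


def read_big_endian(chunk: bytes, dtype: torch.dtype, shape: List[int]) -> torch.Tensor:
    """Read a big endian binary chunk into a tensor of the given dtype and shape."""
    # Element width in bytes, using finfo for floating point and iinfo for integer dtypes.
    num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
    # torch.frombuffer() always assumes native byte order, so multi-byte values
    # have to be byte-reversed first on little endian machines.
    needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
    buffer = bytearray(chunk)
    if not needs_byte_reversal:
        return torch.frombuffer(buffer, dtype=dtype).reshape(shape)
    # Reversing the whole buffer flips both the bytes inside each element and the
    # element order; flip(0) restores the original element order afterwards.
    buffer.reverse()
    return torch.frombuffer(buffer, dtype=dtype).flip(0).reshape(shape)


# Two big endian int32 values, 1 and 2:
print(read_big_endian(b"\x00\x00\x00\x01\x00\x00\x00\x02", torch.int32, [2]))  # tensor([1, 2])
```

Reversing the whole buffer at once avoids element-wise byte swapping and leaves single-byte dtypes, i.e. the common uint8 image case, untouched.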
@@ -5,9 +5,9 @@ import io
 import operator
 import pathlib
 import string
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union
+import sys
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
 
-import numpy as np
 import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -38,14 +38,14 @@ __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
 prod = functools.partial(functools.reduce, operator.mul)
 
 
-class MNISTFileReader(IterDataPipe[np.ndarray]):
+class MNISTFileReader(IterDataPipe[torch.Tensor]):
     _DTYPE_MAP = {
-        8: "u1",  # uint8
-        9: "i1",  # int8
-        11: "i2",  # int16
-        12: "i4",  # int32
-        13: "f4",  # float32
-        14: "f8",  # float64
+        8: torch.uint8,
+        9: torch.int8,
+        11: torch.int16,
+        12: torch.int32,
+        13: torch.float32,
+        14: torch.float64,
     }
 
     def __init__(
@@ -59,18 +59,20 @@ class MNISTFileReader(IterDataPipe[np.ndarray]):
     def _decode(bytes: bytes) -> int:
         return int(codecs.encode(bytes, "hex"), 16)
 
-    def __iter__(self) -> Iterator[np.ndarray]:
+    def __iter__(self) -> Iterator[torch.Tensor]:
         for _, file in self.datapipe:
             magic = self._decode(file.read(4))
-            dtype_type = self._DTYPE_MAP[magic // 256]
+            dtype = self._DTYPE_MAP[magic // 256]
             ndim = magic % 256 - 1
 
             num_samples = self._decode(file.read(4))
             shape = [self._decode(file.read(4)) for _ in range(ndim)]
 
-            in_dtype = np.dtype(f">{dtype_type}")
-            out_dtype = np.dtype(dtype_type)
-            chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize
+            num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
+            # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
+            # we need to reverse the bytes before we can read them with torch.frombuffer().
+            needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
+            chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
 
             start = self.start or 0
             stop = self.stop or num_samples
@@ -78,11 +80,15 @@ class MNISTFileReader(IterDataPipe[np.ndarray]):
                 file.seek(start * chunk_size, 1)
 
             for _ in range(stop - start):
                 chunk = file.read(chunk_size)
-                yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape)
+                if not needs_byte_reversal:
+                    yield torch.frombuffer(chunk, dtype=dtype).reshape(shape)
+                else:
+                    chunk = bytearray(chunk)
+                    chunk.reverse()
+                    yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape)
 
 
 class _MNISTBase(Dataset):
     _FORMAT = "png"
 
     _URL_BASE: str
 
     @abc.abstractmethod
@@ -105,24 +111,23 @@ class _MNISTBase(Dataset):
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
+        image, label = data
 
-        image: Union[torch.Tensor, io.BytesIO]
         if decoder is raw:
-            image = torch.from_numpy(image_array)
+            image = image.unsqueeze(0)
         else:
-            image_buffer = image_buffer_from_array(image_array)
-            image = decoder(image_buffer) if decoder else image_buffer
+            image_buffer = image_buffer_from_array(image.numpy())
+            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]
 
-        label = torch.tensor(label_array, dtype=torch.int64)
         category = self.info.categories[int(label)]
+        label = label.to(torch.int64)
 
-        return dict(image=image, label=label, category=category)
+        return dict(image=image, category=category, label=label)
 
     def _make_datapipe(
         self,
@@ -293,12 +298,11 @@ class EMNIST(_MNISTBase):
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
         # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example,
@@ -308,8 +312,8 @@
         # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
         # in self.categories. Thus, we need to add 1 to the label to correct this.
         if config.image_set in ("Balanced", "By_Merge"):
-            label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype)
-        return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder)
+            data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0)
+        return super()._collate_and_decode(data, config=config, decoder=decoder)
 
     def _make_datapipe(
         self,
@@ -379,22 +383,22 @@ class QMNIST(_MNISTBase):
     def _collate_and_decode(
         self,
-        data: Tuple[np.ndarray, np.ndarray],
+        data: Tuple[torch.Tensor, torch.Tensor],
         *,
         config: DatasetConfig,
         decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> Dict[str, Any]:
-        image_array, label_array = data
-        label_parts = label_array.tolist()
-        sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder)
+        image, ann = data
+        label, *extra_anns = ann
+        sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)
         sample.update(
             dict(
                 zip(
                     ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
-                    label_parts[1:6],
+                    [int(value) for value in extra_anns[:5]],
                 )
             )
         )
-        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]])))
+        sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
         return sample
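
As a side note on the EMNIST hunk above: the label correction relies on a sparse offset map, where only labels that sit above a merged class carry a non-zero offset and everything else falls back to 0. A tiny hypothetical illustration of that lookup pattern follows; the offset values are invented and are not the real EMNIST._LABEL_OFFSETS.

```python
# Invented offsets for illustration: labels at or above the first gap shift by one.
LABEL_OFFSETS = {38: 1, 39: 1, 40: 1}


def correct_label(label: int) -> int:
    # Labels below any gap are not in the map, so .get() returns 0 and they stay put.
    return label + LABEL_OFFSETS.get(label, 0)


print(correct_label(10))  # 10 -> 10 (digits are unaffected)
print(correct_label(38))  # 38 -> 39 (re-aligned with the index in the category list)
```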