"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "2e7d799f9cc5b8544c9ef07330c4cd7eacd894ee"
Unverified commit 29aa38ae, authored by Philip Meier and committed by GitHub
Browse files

replace np.frombuffer with torch.frombuffer in MNIST prototype (#4651)

* replace np.frombuffer with torch.frombuffer in MNIST prototype

* cleanup

* appease mypy

* more cleanup

* clarify inplace offset

* fix num bytes for floating point data
parent 979ecaca
...@@ -5,9 +5,9 @@ import io ...@@ -5,9 +5,9 @@ import io
import operator import operator
import pathlib import pathlib
import string import string
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast, Union import sys
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast
import numpy as np
import torch import torch
from torchdata.datapipes.iter import ( from torchdata.datapipes.iter import (
IterDataPipe, IterDataPipe,
...@@ -38,14 +38,14 @@ __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"] ...@@ -38,14 +38,14 @@ __all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
prod = functools.partial(functools.reduce, operator.mul) prod = functools.partial(functools.reduce, operator.mul)
class MNISTFileReader(IterDataPipe[np.ndarray]): class MNISTFileReader(IterDataPipe[torch.Tensor]):
_DTYPE_MAP = { _DTYPE_MAP = {
8: "u1", # uint8 8: torch.uint8,
9: "i1", # int8 9: torch.int8,
11: "i2", # int16 11: torch.int16,
12: "i4", # int32 12: torch.int32,
13: "f4", # float32 13: torch.float32,
14: "f8", # float64 14: torch.float64,
} }
def __init__( def __init__(
...@@ -59,18 +59,20 @@ class MNISTFileReader(IterDataPipe[np.ndarray]): ...@@ -59,18 +59,20 @@ class MNISTFileReader(IterDataPipe[np.ndarray]):
def _decode(bytes: bytes) -> int: def _decode(bytes: bytes) -> int:
return int(codecs.encode(bytes, "hex"), 16) return int(codecs.encode(bytes, "hex"), 16)
def __iter__(self) -> Iterator[np.ndarray]: def __iter__(self) -> Iterator[torch.Tensor]:
for _, file in self.datapipe: for _, file in self.datapipe:
magic = self._decode(file.read(4)) magic = self._decode(file.read(4))
dtype_type = self._DTYPE_MAP[magic // 256] dtype = self._DTYPE_MAP[magic // 256]
ndim = magic % 256 - 1 ndim = magic % 256 - 1
num_samples = self._decode(file.read(4)) num_samples = self._decode(file.read(4))
shape = [self._decode(file.read(4)) for _ in range(ndim)] shape = [self._decode(file.read(4)) for _ in range(ndim)]
in_dtype = np.dtype(f">{dtype_type}") num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
out_dtype = np.dtype(dtype_type) # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
chunk_size = (cast(int, prod(shape)) if shape else 1) * in_dtype.itemsize # we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
chunk_size = (cast(int, prod(shape)) if shape else 1) * num_bytes_per_value
start = self.start or 0 start = self.start or 0
stop = self.stop or num_samples stop = self.stop or num_samples
...@@ -78,11 +80,15 @@ class MNISTFileReader(IterDataPipe[np.ndarray]): ...@@ -78,11 +80,15 @@ class MNISTFileReader(IterDataPipe[np.ndarray]):
file.seek(start * chunk_size, 1) file.seek(start * chunk_size, 1)
for _ in range(stop - start): for _ in range(stop - start):
chunk = file.read(chunk_size) chunk = file.read(chunk_size)
yield np.frombuffer(chunk, dtype=in_dtype).astype(out_dtype).reshape(shape) if not needs_byte_reversal:
yield torch.frombuffer(chunk, dtype=dtype).reshape(shape)
chunk = bytearray(chunk)
chunk.reverse()
yield torch.frombuffer(chunk, dtype=dtype).flip(0).reshape(shape)
class _MNISTBase(Dataset): class _MNISTBase(Dataset):
_FORMAT = "png"
_URL_BASE: str _URL_BASE: str
@abc.abstractmethod @abc.abstractmethod
...@@ -105,24 +111,23 @@ class _MNISTBase(Dataset): ...@@ -105,24 +111,23 @@ class _MNISTBase(Dataset):
def _collate_and_decode( def _collate_and_decode(
self, self,
data: Tuple[np.ndarray, np.ndarray], data: Tuple[torch.Tensor, torch.Tensor],
*, *,
config: DatasetConfig, config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
image_array, label_array = data image, label = data
image: Union[torch.Tensor, io.BytesIO]
if decoder is raw: if decoder is raw:
image = torch.from_numpy(image_array) image = image.unsqueeze(0)
else: else:
image_buffer = image_buffer_from_array(image_array) image_buffer = image_buffer_from_array(image.numpy())
image = decoder(image_buffer) if decoder else image_buffer image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
label = torch.tensor(label_array, dtype=torch.int64)
category = self.info.categories[int(label)] category = self.info.categories[int(label)]
label = label.to(torch.int64)
return dict(image=image, label=label, category=category) return dict(image=image, category=category, label=label)
def _make_datapipe( def _make_datapipe(
self, self,
...@@ -293,12 +298,11 @@ class EMNIST(_MNISTBase): ...@@ -293,12 +298,11 @@ class EMNIST(_MNISTBase):
def _collate_and_decode( def _collate_and_decode(
self, self,
data: Tuple[np.ndarray, np.ndarray], data: Tuple[torch.Tensor, torch.Tensor],
*, *,
config: DatasetConfig, config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
image_array, label_array = data
# In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
# That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
# i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example, # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For example,
...@@ -308,8 +312,8 @@ class EMNIST(_MNISTBase): ...@@ -308,8 +312,8 @@ class EMNIST(_MNISTBase):
# index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
# in self.categories. Thus, we need to add 1 to the label to correct this. # in self.categories. Thus, we need to add 1 to the label to correct this.
if config.image_set in ("Balanced", "By_Merge"): if config.image_set in ("Balanced", "By_Merge"):
label_array += np.array(self._LABEL_OFFSETS.get(int(label_array), 0), dtype=label_array.dtype) data[1] += self._LABEL_OFFSETS.get(int(data[1]), 0)
return super()._collate_and_decode((image_array, label_array), config=config, decoder=decoder) return super()._collate_and_decode(data, config=config, decoder=decoder)
def _make_datapipe( def _make_datapipe(
self, self,
...@@ -379,22 +383,22 @@ class QMNIST(_MNISTBase): ...@@ -379,22 +383,22 @@ class QMNIST(_MNISTBase):
def _collate_and_decode( def _collate_and_decode(
self, self,
data: Tuple[np.ndarray, np.ndarray], data: Tuple[torch.Tensor, torch.Tensor],
*, *,
config: DatasetConfig, config: DatasetConfig,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
image_array, label_array = data image, ann = data
label_parts = label_array.tolist() label, *extra_anns = ann
sample = super()._collate_and_decode((image_array, label_parts[0]), config=config, decoder=decoder) sample = super()._collate_and_decode((image, label), config=config, decoder=decoder)
sample.update( sample.update(
dict( dict(
zip( zip(
("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"), ("nist_hsf_series", "nist_writer_id", "digit_index", "nist_label", "global_digit_index"),
label_parts[1:6], [int(value) for value in extra_anns[:5]],
) )
) )
) )
sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in label_parts[-2:]]))) sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
return sample return sample
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment