Unverified commit d605d7d4, authored by Sergii Khomenko, committed by GitHub
Browse files

Migrate mnist dataset from np.frombuffer (#4598)



* Migrate mnist dataset from np.frombuffer

* Add a copy with bytearray for non-writable buffers

* Add byte reversal for mnist
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 4ba91bff
......@@ -3,6 +3,7 @@ import os
import os.path
import shutil
import string
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
......@@ -489,12 +490,12 @@ def get_int(b: bytes) -> int:
SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8),
9: (torch.int8, np.int8, np.int8),
11: (torch.int16, np.dtype(">i2"), "i2"),
12: (torch.int32, np.dtype(">i4"), "i4"),
13: (torch.float32, np.dtype(">f4"), "f4"),
14: (torch.float64, np.dtype(">f8"), "f8"),
8: torch.uint8,
9: torch.int8,
11: torch.int16,
12: torch.int32,
13: torch.float32,
14: torch.float64,
}
......@@ -511,11 +512,19 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
ty = magic // 256
assert 1 <= nd <= 3
assert 8 <= ty <= 14
m = SN3_PASCALVINCENT_TYPEMAP[ty]
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
num_bytes_per_value = torch.iinfo(torch_type).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
if needs_byte_reversal:
parsed = parsed.flip(0)
assert parsed.shape[0] == np.prod(s) or not strict
return torch.from_numpy(parsed.astype(m[2])).view(*s)
return parsed.view(*s)
def read_label_file(path: str) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment