"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ff91f154eedfa6d27d9c6786b422a613f44e9edd"
Unverified Commit d605d7d4 authored by Sergii Khomenko, committed by GitHub
Browse files

Migrate mnist dataset from np.frombuffer (#4598)



* Migrate mnist dataset from np.frombuffer

* Add a copy with bytearray for non-writable buffers

* Add byte reversal for mnist
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 4ba91bff
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import os.path import os.path
import shutil import shutil
import string import string
import sys
import warnings import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError from urllib.error import URLError
...@@ -489,12 +490,12 @@ def get_int(b: bytes) -> int: ...@@ -489,12 +490,12 @@ def get_int(b: bytes) -> int:
SN3_PASCALVINCENT_TYPEMAP = { SN3_PASCALVINCENT_TYPEMAP = {
8: (torch.uint8, np.uint8, np.uint8), 8: torch.uint8,
9: (torch.int8, np.int8, np.int8), 9: torch.int8,
11: (torch.int16, np.dtype(">i2"), "i2"), 11: torch.int16,
12: (torch.int32, np.dtype(">i4"), "i4"), 12: torch.int32,
13: (torch.float32, np.dtype(">f4"), "f4"), 13: torch.float32,
14: (torch.float64, np.dtype(">f8"), "f8"), 14: torch.float64,
} }
...@@ -511,11 +512,19 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso ...@@ -511,11 +512,19 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso
ty = magic // 256 ty = magic // 256
assert 1 <= nd <= 3 assert 1 <= nd <= 3
assert 8 <= ty <= 14 assert 8 <= ty <= 14
m = SN3_PASCALVINCENT_TYPEMAP[ty] torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
num_bytes_per_value = torch.iinfo(torch_type).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
if needs_byte_reversal:
parsed = parsed.flip(0)
assert parsed.shape[0] == np.prod(s) or not strict assert parsed.shape[0] == np.prod(s) or not strict
return torch.from_numpy(parsed.astype(m[2])).view(*s) return parsed.view(*s)
def read_label_file(path: str) -> torch.Tensor: def read_label_file(path: str) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment