Unverified Commit dcf5dc87 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

fix `fromfile` on windows (#4980)

* exclude windows from mmap

* add comment for windows

* appease mypy
parent ba299e8f
...@@ -8,6 +8,7 @@ import os ...@@ -8,6 +8,7 @@ import os
import os.path import os.path
import pathlib import pathlib
import pickle import pickle
import platform
from typing import BinaryIO from typing import BinaryIO
from typing import ( from typing import (
Sequence, Sequence,
...@@ -260,6 +261,11 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe: ...@@ -260,6 +261,11 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
return dp return dp
def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
    """Read raw item data from ``file`` into a mutable buffer.

    Args:
        file: Binary file object to read from, starting at its current position.
        count: Number of items to read. ``-1`` means "read everything that is left".
        item_size: Size of a single item in bytes.

    Returns:
        A ``bytearray`` holding a copy of the bytes that were read.
    """
    if count == -1:
        num_bytes = -1
    else:
        num_bytes = count * item_size

    # file.read() hands back immutable bytes; copying them into a bytearray gives the caller
    # a writable buffer.
    raw = file.read(num_bytes)
    return bytearray(raw)
def fromfile( def fromfile(
file: BinaryIO, file: BinaryIO,
*, *,
...@@ -293,20 +299,24 @@ def fromfile( ...@@ -293,20 +299,24 @@ def fromfile(
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8 item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
np_dtype = byte_order + char + str(item_size) np_dtype = byte_order + char + str(item_size)
buffer: Union[memoryview, bytearray]
if platform.system() != "Windows":
# PyTorch does not support tensors with underlying read-only memory. In case # PyTorch does not support tensors with underlying read-only memory. In case
# - the file has a .fileno(), # - the file has a .fileno(),
# - the file was opened for updating, i.e. 'r+b' or 'w+b', # - the file was opened for updating, i.e. 'r+b' or 'w+b',
# - the file is seekable # - the file is seekable
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to # we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
# a mutable location afterwards. # to a mutable location afterwards.
buffer: Union[memoryview, bytearray]
try: try:
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :] buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
# Reading from the memoryview does not advance the file cursor, so we have to do it manually. # Reading from the memoryview does not advance the file cursor, so we have to do it manually.
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR)) file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
except (PermissionError, io.UnsupportedOperation): except (PermissionError, io.UnsupportedOperation):
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable buffer = _read_mutable_buffer_fallback(file, count, item_size)
buffer = bytearray(file.read(-1 if count == -1 else count * item_size)) else:
# On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
# so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
buffer = _read_mutable_buffer_fallback(file, count, item_size)
# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we # We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the # read the data with np.frombuffer() with the correct byte order and convert it to the native one with the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment