".github/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "709908a76b3503084acdcf9d26f290e60cc4db06"
Unverified Commit b50ffef5 authored by Francisco Massa, committed by GitHub

[FBcode->GH] Enabling sharded datasets (#4776) (#4796)

Summary:
This diff enables ImageNet to be usable from within torchvision.

It implements a separate codepath for families of datasets that can be sharded into multiple pieces, and uses that codepath when running within fbcode.

For now, all previously supported datasets (MNIST, CIFAR, and FashionMNIST) remain supported, in addition to ImageNet.
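As a minimal usage sketch (not part of this diff; it assumes `use_sharded_dataset` and `load` are re-exported from `torchvision.prototype.datasets`, which is an assumption):

    # Hypothetical sketch: opting into the sharded codepath added by this diff.
    from torchvision.prototype import datasets

    datasets.use_sharded_dataset(True)  # or set TORCHVISION_SHARDED_DATASETS=1
    datapipe = datasets.load("imagenet", split="val")  # an IterDataPipe of Dict[str, Any] samples
    for sample in datapipe:
        ...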

Pull Request resolved: https://github.com/pytorch/vision/pull/4776

Reviewed By: datumbox

Differential Revision: D31929684

fbshipit-source-id: 58f9fe2ed6ce731a6b43a38be912983247eff562
(cherry picked from commit b43fead8ea5164e37637ade0e80477e94f993365)
parent 452ff86c
 import importlib.machinery
 import os
 
+from torch.hub import _get_torch_home
+
+_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
+_USE_SHARDED_DATASETS = False
+
 
 def _download_file_from_remote_location(fpath: str, url: str) -> None:
     pass
......
 import io
+import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
...
@@ -63,12 +64,13 @@ def load(
     split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
+    name = name.lower()
     dataset = find(name)
 
     if decoder is default:
         decoder = DEFAULT_DECODER.get(dataset.info.type)
 
     config = dataset.info.make_config(split=split, **options)
-    root = home() / name
+    root = os.path.join(home(), name)
 
     return dataset.to_datapipe(root, config=config, decoder=decoder)
...
@@ -38,9 +38,13 @@ class ImageNet(Dataset):
             extra=dict(
                 wnid_to_category=FrozenMapping(zip(wnids, categories)),
                 category_to_wnid=FrozenMapping(zip(categories, wnids)),
+                sizes=FrozenMapping([(DatasetConfig(split="train"), 1281167), (DatasetConfig(split="val"), 50000)]),
             ),
         )
 
+    def supports_sharded(self) -> bool:
+        return True
+
     @property
     def category_to_wnid(self) -> Dict[str, str]:
         return cast(Dict[str, str], self.info.extra.category_to_wnid)
......
 import os
-import pathlib
-from typing import Optional, Union
+from typing import Optional
 
-from torch.hub import _get_torch_home
-
-HOME = pathlib.Path(_get_torch_home()) / "datasets" / "vision"
+import torchvision._internally_replaced_utils as _iru
 
 
-def home(root: Optional[Union[str, pathlib.Path]] = None) -> pathlib.Path:
-    global HOME
+def home(root: Optional[str] = None) -> str:
     if root is not None:
-        HOME = pathlib.Path(root).expanduser().resolve()
-        return HOME
+        _iru._HOME = root
+        return _iru._HOME
 
     root = os.getenv("TORCHVISION_DATASETS_HOME")
     if root is not None:
-        return pathlib.Path(root)
+        return root
 
-    return HOME
+    return _iru._HOME
+
+
+def use_sharded_dataset(use: Optional[bool] = None) -> bool:
+    if use is not None:
+        _iru._USE_SHARDED_DATASETS = use
+        return _iru._USE_SHARDED_DATASETS
+
+    use = os.getenv("TORCHVISION_SHARDED_DATASETS")
+    if use is not None:
+        return use == "1"
+
+    return _iru._USE_SHARDED_DATASETS
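Both helpers resolve their value in the same order: an explicit argument (which is also persisted), then an environment variable, then the module-level default in `_internally_replaced_utils`. A small illustration of the resulting behavior, assuming `home` is re-exported from `torchvision.prototype.datasets`:

    # Illustration only: precedence of the three configuration sources.
    import os
    from torchvision.prototype import datasets  # re-export of home() is an assumption

    os.environ["TORCHVISION_DATASETS_HOME"] = "/data/vision"
    print(datasets.home())               # "/data/vision": env var beats the module default
    print(datasets.home("/tmp/vision"))  # "/tmp/vision": explicit argument wins and is persisted
    print(datasets.home())               # "/data/vision" again: the env var is checked first

Note the last call: because the environment variable is consulted before the persisted default, a root set via `home(root)` does not override `TORCHVISION_DATASETS_HOME`.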
...
@@ -2,6 +2,7 @@ import abc
 import csv
 import enum
 import io
+import os
 import pathlib
 from typing import Any, Callable, Dict, List, Optional, Sequence, Union, Tuple
...
@@ -12,7 +13,8 @@ from torchvision.prototype.datasets.utils._internal import (
     sequence_to_str,
 )
 
-from ._internal import FrozenBunch, make_repr, BUILTIN_DIR
+from .._home import use_sharded_dataset
+from ._internal import FrozenBunch, make_repr, BUILTIN_DIR, _make_sharded_datapipe
 from ._resource import OnlineResource
...
@@ -150,6 +152,9 @@ class Dataset(abc.ABC):
     ) -> IterDataPipe[Dict[str, Any]]:
         pass
 
+    def supports_sharded(self) -> bool:
+        return False
+
     def to_datapipe(
         self,
         root: Union[str, pathlib.Path],
...
@@ -160,6 +165,10 @@ class Dataset(abc.ABC):
         if not config:
             config = self.info.default_config
 
+        if use_sharded_dataset() and self.supports_sharded():
+            root = os.path.join(root, *config.values())
+            dataset_size = self.info.extra["sizes"][config]
+            return _make_sharded_datapipe(root, dataset_size)
         resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)]
         return self._make_datapipe(resource_dps, config=config, decoder=decoder)
......
...
@@ -8,6 +8,7 @@ import lzma
 import os
 import os.path
 import pathlib
+import pickle
 import textwrap
 from typing import (
     Collection,
...
@@ -21,14 +22,20 @@ from typing import (
     Dict,
     Optional,
     NoReturn,
+    IO,
     Iterable,
     Mapping,
+    Sized,
 )
 
 from typing import cast
 
 import numpy as np
 import PIL.Image
+import torch.distributed as dist
+import torch.utils.data
 from torch.utils.data import IterDataPipe
+from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader
 
 __all__ = [
     "INFINITE_BUFFER_SIZE",
...
@@ -277,3 +284,79 @@ class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
             type = self._detect_compression_type(path)
             decompressor = self._DECOMPRESSORS[type]
             yield path, decompressor(file)
+
+
+class PicklerDataPipe(IterDataPipe):
+    def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO[bytes]]]) -> None:
+        self.source_datapipe = source_datapipe
+
+    def __iter__(self) -> Iterator[Any]:
+        for _, fobj in self.source_datapipe:
+            data = pickle.load(fobj)
+            for _, d in enumerate(data):
+                yield d
+
+
+class SharderDataPipe(torch.utils.data.datapipes.iter.grouping.ShardingFilterIterDataPipe):
+    def __init__(self, source_datapipe: IterDataPipe) -> None:
+        super().__init__(source_datapipe)
+        self.rank = 0
+        self.world_size = 1
+        if dist.is_available() and dist.is_initialized():
+            self.rank = dist.get_rank()
+            self.world_size = dist.get_world_size()
+        self.apply_sharding(self.world_size, self.rank)
+
+    def __iter__(self) -> Iterator[Any]:
+        num_workers = self.world_size
+        worker_id = self.rank
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            worker_id = worker_id + worker_info.id * num_workers
+            num_workers *= worker_info.num_workers
+        self.apply_sharding(num_workers, worker_id)
+        yield from super().__iter__()
+
+
+class TakerDataPipe(IterDataPipe):
+    def __init__(self, source_datapipe: IterDataPipe, num_take: int) -> None:
+        super().__init__()
+        self.source_datapipe = source_datapipe
+        self.num_take = num_take
+        self.world_size = 1
+        if dist.is_available() and dist.is_initialized():
+            self.world_size = dist.get_world_size()
+
+    def __iter__(self) -> Iterator[Any]:
+        num_workers = self.world_size
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            num_workers *= worker_info.num_workers
+
+        # TODO: this is weird as it drops more elements than it should
+        num_take = self.num_take // num_workers
+
+        for i, data in enumerate(self.source_datapipe):
+            if i < num_take:
+                yield data
+            else:
+                break
+
+    def __len__(self) -> int:
+        num_take = self.num_take // self.world_size
+        if isinstance(self.source_datapipe, Sized):
+            if len(self.source_datapipe) < num_take:
+                num_take = len(self.source_datapipe)
+        # TODO: might be weird to not take `num_workers` into account
+        return num_take
+
+
+def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
+    dp = IoPathFileLister(root=root)
+    dp = SharderDataPipe(dp)
+    dp = dp.shuffle(buffer_size=INFINITE_BUFFER_SIZE)
+    dp = IoPathFileLoader(dp, mode="rb")
+    dp = PicklerDataPipe(dp)
+    # dp = dp.cycle(2)
+    dp = TakerDataPipe(dp, dataset_size)
+    return dp
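Taken together, `_make_sharded_datapipe` lists the shard files under `root`, partitions them across `world_size * num_workers` consumers, shuffles, loads and unpickles each file, and truncates to the expected dataset size. The global shard id computed in `SharderDataPipe.__iter__` composes the distributed rank with the DataLoader worker id; a standalone illustration of that arithmetic:

    # Illustration only: the global shard id SharderDataPipe.__iter__ computes.
    def global_shard(rank: int, world_size: int, worker_id: int, num_workers: int):
        shard_id = rank + worker_id * world_size  # worker ids stride over the ranks
        total_shards = world_size * num_workers
        return shard_id, total_shards

    # 2 ranks with 3 DataLoader workers each -> 6 disjoint shards, ids 0..5:
    for rank in range(2):
        for worker in range(3):
            print(rank, worker, global_shard(rank, world_size=2, worker_id=worker, num_workers=3))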
...
@@ -4,7 +4,8 @@ from typing import Optional, Union
 from urllib.parse import urlparse
 
 from torch.utils.data import IterDataPipe
-from torch.utils.data.datapipes.iter import FileLoader, IterableWrapper
+from torch.utils.data.datapipes.iter import IterableWrapper
+from torchdata.datapipes.iter import IoPathFileLoader
 
 
 # FIXME
...
@@ -19,7 +20,7 @@ class LocalResource:
         self.sha256 = sha256 or compute_sha256(self.path)
 
     def to_datapipe(self) -> IterDataPipe:
-        return FileLoader(IterableWrapper((str(self.path),)))
+        return IoPathFileLoader(IterableWrapper((str(self.path),)), mode="rb")  # type: ignore
 
 
 class OnlineResource:
...
@@ -29,9 +30,9 @@ class OnlineResource:
         self.file_name = file_name
 
     def to_datapipe(self, root: Union[str, pathlib.Path]) -> IterDataPipe:
-        path = (pathlib.Path(root) / self.file_name).expanduser().resolve()
+        path = os.path.join(root, self.file_name)
         # FIXME
-        return FileLoader(IterableWrapper((str(path),)))
+        return IoPathFileLoader(IterableWrapper((str(path),)), mode="rb")  # type: ignore
 
 
 # TODO: add support for mirrors
......
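A note on the design choice: the switch from `FileLoader` to torchdata's `IoPathFileLoader` (and from `pathlib` joins to plain `os.path.join` strings) presumably lets the same resource code read from iopath-backed storage rather than only the local filesystem, which would match the fbcode-only sharded codepath described in the summary; the commit message does not state this explicitly.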