Commit 87374efb authored by Yanghan Wang, committed by Facebook GitHub Bot

add option to use disk cache to store underlying dataset

Summary:
# TLDR: To use this feature, set `D2GO_DATA.DATASETS.DISK_CACHE.ENABLED` to `True`.
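
For example (a minimal sketch mirroring the unit test added below; it assumes the `GeneralizedRCNNRunner` default config, but any D2Go runner config works the same way):

```python
from d2go.runner import create_runner

# Build a default D2Go config and flip the disk-cache flag before building data loaders.
runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
cfg = runner.get_default_cfg()
cfg.merge_from_list(["D2GO_DATA.DATASETS.DISK_CACHE.ENABLED", "True"])
```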

To support larger datasets, one idea is to offload the `DatasetFromList` from RAM to disk to avoid OOM. `DiskCachedDatasetFromList` is a drop-in replacement for `DatasetFromList`: during `__init__` it puts the serialized list onto disk and only stores the mapping in RAM (the mapping could be a list of addresses, or even just a single number, e.g. when every N items are grouped together and N is fixed); `__getitem__` then reads the data from disk and deserializes the element (a minimal sketch of this idea follows the list below). Some more details:
- Originally the RAM cost is `O(s*G*N)`, where `s` is the average data size, `G` is the number of GPUs, and `N` is the dataset size. When the disk cache is enabled, depending on the type of mapping, the final RAM cost is constant or `O(N)` with a very small coefficient; the final disk cost is `O(s*N)`.
- RAM usage peaks at the preparing stage with cost `O(s*N)`; if this becomes a bottleneck, we probably need to think about modifying the data loading function (registered in `DatasetCatalog`). We also change the data loading function to run only on the local master process, otherwise RAM would peak at `O(s*G*N)` when all processes load data at the same time.
- The time overhead of initialization is linear in the dataset size and is capped by disk I/O speed and the performance of the `diskcache` library. Benchmarks show it can handle at least 1 GB per minute when writing in chunks (much worse if not), which should be fine in most use cases.
- There is also a small time overhead when reading the data, but it is usually negligible compared with reading files from external storage like Manifold.
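
A minimal sketch of the core idea (hypothetical and simplified: one `diskcache` key per element, roughly the "naive" strategy, without the chunking, multi-process coordination, and cleanup logic of the real `DiskCachedDatasetFromList` added in this diff):

```python
import pickle
import tempfile

import diskcache
import torch.utils.data as data


class NaiveDiskCachedList(data.Dataset):
    """Keep only the element count in RAM; store each serialized element on disk."""

    def __init__(self, lst):
        self._cache = diskcache.Cache(directory=tempfile.mkdtemp())
        for i, item in enumerate(lst):
            # one key per element, pickled with the highest protocol
            self._cache.set(str(i), pickle.dumps(item, protocol=-1), retry=True)
        self._size = len(lst)

    def __len__(self):
        return self._size

    def __getitem__(self, idx):
        # read the serialized blob back from disk and deserialize it
        return pickle.loads(self._cache[str(idx)])
```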

It's not easy to integrate this into D2/D2Go cleanly without patching the code. Several approaches (the chosen patching approach is sketched after this list):
- Integrate into D2 directly (modifying D2's `DatasetFromList` and `get_detection_dataset_dicts`): might be the cleanest way, but D2 doesn't depend on `diskcache` and this feature is still a bit experimental.
- Let D2Go use its own version of [_train_loader_from_config](https://fburl.com/code/0gig5tj2) that wraps the returned `dataset`. This has two issues: 1) it's hard to make the underlying `get_detection_dataset_dicts` run only on the local master, partly because building the sampler uses `comm.shared_random_seed()`, so things can easily go out of sync; 2) it needs some duplicated code for the test loader.
- Pass new arguments along the way; this requires touching D2's code as well, and the new arguments would need to be carried through a lot of places.
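
For illustration, the patching pattern boils down to something like the following (a simplified sketch; the actual implementation added in this diff is `enable_disk_cached_dataset`, which also patches `get_detection_dataset_dicts` and the D2Go call sites):

```python
import contextlib
from unittest import mock

from d2go.data.utils import DiskCachedDatasetFromList


@contextlib.contextmanager
def _patch_dataset_from_list():
    # Swap D2's in-memory DatasetFromList for the disk-backed version at its call site.
    with mock.patch(
        "detectron2.data.build.DatasetFromList",
        side_effect=lambda lst, **kwargs: DiskCachedDatasetFromList(lst),
    ):
        yield
```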

Lots of TODOs:
- Automatically enable this when the dataset is larger than a certain threshold (need to figure out how to do this with multiple GPUs; some communication is needed if only the local master reads the dataset).
- better cleanups
- figure out the best way of integrating this into D2/D2Go (patching is a bit hacky).
- run more benchmarks
- add unit test (maybe also enable integration tests using 2 nodes 2 GPUs for distributed settings)

Reviewed By: sstsai-adl

Differential Revision: D27451187

fbshipit-source-id: 7d329e1a3c3f9ec1fb9ada0298a52a33f2730e15
parent fb0164c3
@@ -34,6 +34,10 @@ def add_d2go_data_default_configs(_C):
# by specifying the filename (without .py).
_C.D2GO_DATA.DATASETS.DYNAMIC_DATASETS = []
# Config for caching the dataset annotations on local disk
_C.D2GO_DATA.DATASETS.DISK_CACHE = CN()
_C.D2GO_DATA.DATASETS.DISK_CACHE.ENABLED = False
# TODO: potentially add this config
# # List of extra keys in annotation, the item will be forwarded by
# # extended_coco_load.
...
@@ -7,17 +7,16 @@ import contextlib
import json
import logging
import os
import pickle
import re
import shutil
import tempfile
import uuid
from collections import defaultdict
from unittest import mock
import numpy as np
import torch.utils.data as data
logger = logging.getLogger(__name__)
from d2go.config import temp_defrost
from d2go.data.datasets import (
register_dataset_split,
@@ -26,7 +25,14 @@ from d2go.data.datasets import (
INJECTED_COCO_DATASETS_LUT,
)
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.build import (
get_detection_dataset_dicts as d2_get_detection_dataset_dicts,
)
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import log_every_n_seconds
logger = logging.getLogger(__name__)
class AdhocDatasetManager:
@@ -379,3 +385,277 @@ def update_cfg_if_using_adhoc_dataset(cfg):
cfg.DATASETS.TEST = tuple(ds.new_ds_name for ds in new_test_datasets)
return cfg
def _local_master_gather(func, check_equal=False):
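# Run `func` only on local-master ranks (local_rank == 0), all_gather the results,
# and return the list of values collected from each host's local master; optionally
# check that all local masters produced the same value.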
if comm.get_local_rank() == 0:
x = func()
assert x is not None
else:
x = None
x_all = comm.all_gather(x)
x_local_master = [x for x in x_all if x is not None]
if check_equal:
master = x_local_master[0]
assert all(x == master for x in x_local_master), x_local_master
return x_local_master
class DiskCachedDatasetFromList(data.Dataset):
"""
Wrap a list into a torch Dataset; the underlying storage is off-loaded to disk to
save RAM usage.
"""
CACHE_DIR = "/tmp/DatasetFromList_cache"
_OCCUPIED_CACHE_DIRS = set()
def __init__(self, lst, strategy="batched_static"):
"""
Args:
lst (list): a list which contains elements to produce.
strategy (str): strategy of using diskcache, supported strategies:
- naive: saving each item individually.
- batched_static: group N items together, where N is calculated from
the average item size.
"""
self._lst = lst
self._diskcache_strategy = strategy
def _serialize(data):
buffer = pickle.dumps(data, protocol=-1)
return np.frombuffer(buffer, dtype=np.uint8)
logger.info(
"Serializing {} elements to byte tensors and concatenating them all ...".format(
len(self._lst)
)
)
self._lst = [_serialize(x) for x in self._lst]
# TODO: only enable DiskCachedDatasetFromList for large enough datasets
logger.info(
"Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024 ** 2)
)
self._initialize_diskcache()
def _initialize_diskcache(self):
from mobile_cv.common.misc.local_cache import LocalCache
cache_dir = "{}/{}".format(
DiskCachedDatasetFromList.CACHE_DIR, uuid.uuid4().hex[:8]
)
cache_dir = comm.all_gather(cache_dir)[0] # use same cache_dir
logger.info("Creating diskcache database in: {}".format(cache_dir))
self._cache = LocalCache(cache_dir=cache_dir, num_shards=8)
# self._cache.cache.clear(retry=True) # seems faster if index exists
if comm.get_local_rank() == 0:
DiskCachedDatasetFromList.get_all_cache_dirs().add(self._cache.cache_dir)
if self._diskcache_strategy == "naive":
for i, item in enumerate(self._lst):
ret = self._write_to_local_db((i, item))
assert ret, "Error writing index {} to local db".format(i)
pct = 100.0 * i / len(self._lst)
self._log_progress(pct)
# NOTE: each item might be small in size (hundreds of bytes),
# writing millions of them can take a pretty long time (hours)
# because of frequent disk access. One solution is grouping a batch
# of items into a larger blob.
elif self._diskcache_strategy == "batched_static":
TARGET_BYTES = 50 * 1024
average_bytes = np.average(
[
self._lst[int(x)].size
for x in np.linspace(0, len(self._lst) - 1, 1000)
]
)
self._chuck_size = max(1, int(TARGET_BYTES / average_bytes))
logger.info(
"Average data size: {} bytes; target chuck data size {} KiB;"
" {} items per chuck; {} chucks in total".format(
average_bytes,
TARGET_BYTES / 1024,
self._chuck_size,
int(len(self._lst) / self._chuck_size),
)
)
for i in range(0, len(self._lst), self._chuck_size):
chunk = self._lst[i : i + self._chuck_size]
chunk_i = int(i / self._chuck_size)
ret = self._write_to_local_db((chunk_i, chunk))
assert ret, "Error writing index {} to local db".format(chunk_i)
pct = 100.0 * i / len(self._lst)
self._log_progress(pct)
# NOTE: instead of using fixed chuck size, items can be grouped dynamically
elif self._diskcache_strategy == "batched_dynamic":
raise NotImplementedError()
else:
raise NotImplementedError(self._diskcache_strategy)
comm.synchronize()
logger.info(
"Finished writing to local disk, db size: {:.2f} MiB".format(
self._cache.cache.volume() / 1024 ** 2
)
)
# Optional sync for some strategies
if self._diskcache_strategy == "batched_static":
# propagate chunk size and make sure all local masters use the same value
self._chunk_size = _local_master_gather(
lambda: self._chunk_size, check_equal=True
)[0]
# free the memory of self._lst
self._size = _local_master_gather(lambda: len(self._lst), check_equal=True)[0]
del self._lst
def _write_to_local_db(self, task):
index, record = task
db_path = str(index)
# suc = self._cache.load(lambda path, x: x, db_path, record)
# record = BytesIO(np.random.bytes(np.random.randint(70000, 90000)))
suc = self._cache.cache.set(db_path, record, retry=True)
return suc
def _log_progress(self, percentage):
log_every_n_seconds(
logging.INFO,
"({:.2f}%) Wrote {} elements to local disk cache, db size: {:.2f} MiB".format(
percentage,
len(self._cache.cache),
self._cache.cache.volume() / 1024 ** 2,
),
n=10,
)
def __len__(self):
if self._diskcache_strategy == "batched_static":
return self._size
else:
raise NotImplementedError()
def __getitem__(self, idx):
if self._diskcache_strategy == "naive":
bytes = memoryview(self._cache.cache[str(idx)])
return pickle.loads(bytes)
elif self._diskcache_strategy == "batched_static":
chunk_i, residual = divmod(idx, self._chuck_size)
chunk = self._cache.cache[str(chunk_i)]
bytes = memoryview(chunk[residual])
return pickle.loads(bytes)
else:
raise NotImplementedError()
@classmethod
def get_all_cache_dirs(cls):
"""return all the ocupied cache dirs of DiskCachedDatasetFromList"""
return DiskCachedDatasetFromList._OCCUPIED_CACHE_DIRS
def get_cache_dir(self):
"""return the current cache dirs of DiskCachedDatasetFromList instance"""
return self._cache.cache_dir
@staticmethod
def _clean_up_cache_dir(cache_dir, **kwargs):
print("Cleaning up cache dir: {}".format(cache_dir))
shutil.rmtree(
cache_dir,
onerror=lambda func, path, ex: print(
"Catch error when removing {}; func: {}; exc_info: {}".format(
path, func, ex
)
),
)
@staticmethod
@atexit.register
def _clean_up_all():
# in case the program exits unexpectedly, clean up all the cache dirs created by
# this session.
if comm.get_local_rank() == 0:
for cache_dir in DiskCachedDatasetFromList.get_all_cache_dirs():
DiskCachedDatasetFromList._clean_up_cache_dir(cache_dir)
def __del__(self):
# when the data loader is GC-ed, remove the cache dir. This is needed to avoid
# wasting disk space in case multiple data loaders are used, e.g. running
# evaluations on multiple datasets during training.
if comm.get_local_rank() == 0:
DiskCachedDatasetFromList._clean_up_cache_dir(self._cache.cache_dir)
DiskCachedDatasetFromList.get_all_cache_dirs().remove(self._cache.cache_dir)
class _FakeListObj(list):
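# A placeholder that stands in for the dataset dicts on non-local-master processes:
# it reports the correct length but holds no data, since only the local master
# actually loads the annotations.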
def __init__(self, size):
self.size = size
def __len__(self):
return self.size
def __getitem__(self, idx):
raise NotImplementedError()
def local_master_get_detection_dataset_dicts(*args, **kwargs):
logger.info("Only load dataset dicts on local master process ...")
dataset_dicts = (
d2_get_detection_dataset_dicts(*args, **kwargs)
if comm.get_local_rank() == 0
else []
)
comm.synchronize()
dataset_size = comm.all_gather(len(dataset_dicts))[0]
if comm.get_local_rank() != 0:
dataset_dicts = _FakeListObj(dataset_size)
return dataset_dicts
@contextlib.contextmanager
def enable_disk_cached_dataset(cfg):
"""
Context manager for enabling disk-cached datasets; this is an experimental feature.
- Replace D2's DatasetFromList with DiskCachedDatasetFromList, needs to patch all
call sites.
- Replace D2's get_detection_dataset_dicts with a local-master-only version.
"""
if not cfg.D2GO_DATA.DATASETS.DISK_CACHE.ENABLED:
yield
return
def _patched_dataset_from_list(lst, **kwargs):
logger.info("Patch DatasetFromList with DiskCachedDatasetFromList")
return DiskCachedDatasetFromList(lst)
with contextlib.ExitStack() as stack:
for ctx in [
mock.patch(
"detectron2.data.build.get_detection_dataset_dicts",
side_effect=local_master_get_detection_dataset_dicts,
),
mock.patch(
"detectron2.data.build.DatasetFromList",
side_effect=_patched_dataset_from_list,
),
mock.patch(
"d2go.data.build.DatasetFromList",
side_effect=_patched_dataset_from_list,
),
mock.patch(
"d2go.data.build_fb.DatasetFromList",
side_effect=_patched_dataset_from_list,
),
]:
stack.enter_context(ctx)
yield
@@ -22,6 +22,7 @@ from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.transforms.build import build_transform_gen
from d2go.data.utils import (
enable_disk_cached_dataset,
maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset,
)
@@ -505,12 +506,14 @@ class Detectron2GoRunner(BaseRunner):
logger.info(
"Building detection test loader for dataset: {} ...".format(dataset_name)
)
with enable_disk_cached_dataset(cfg):
mapper = mapper or cls.get_mapper(cfg, is_train=False)
logger.info("Using dataset mapper:\n{}".format(mapper))
return d2_build_detection_test_loader(cfg, dataset_name, mapper=mapper)
@classmethod
def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
with enable_disk_cached_dataset(cfg):
mapper = mapper or cls.get_mapper(cfg, is_train=True)
data_loader = build_d2go_train_loader(cfg, mapper)
return cls._attach_visualizer_to_data_loader(cfg, data_loader)
...
@@ -2,10 +2,17 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import shutil
import unittest
import torch
from d2go.data.utils import DiskCachedDatasetFromList, enable_disk_cached_dataset
from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
create_fake_detection_data_loader,
register_toy_coco_dataset,
)
class TestD2GoDatasetMapper(unittest.TestCase):
@@ -36,3 +43,84 @@ class TestD2GoDatasetMapper(unittest.TestCase):
for data in test_loader:
all_data.append(data)
self.assertEqual(len(all_data), 3)
class _MyClass(object):
def __init__(self, x):
self.x = x
def do_something(self):
return
class TestDiskCachedDataLoader(unittest.TestCase):
def setUp(self):
# make sure the CACHE_DIR is empty when entering the test
if os.path.exists(DiskCachedDatasetFromList.CACHE_DIR):
shutil.rmtree(DiskCachedDatasetFromList.CACHE_DIR)
def _count_cache_dirs(self):
if not os.path.exists(DiskCachedDatasetFromList.CACHE_DIR):
return 0
return len(os.listdir(DiskCachedDatasetFromList.CACHE_DIR))
def test_disk_cached_dataset_from_list(self):
"""Test the class of DiskCachedDatasetFromList"""
# check the disk cache can handle different data types
lst = [1, torch.tensor(2), _MyClass(3)]
disk_cached_lst = DiskCachedDatasetFromList(lst)
self.assertEqual(len(disk_cached_lst), 3)
self.assertEqual(disk_cached_lst[0], 1)
self.assertEqual(disk_cached_lst[1].item(), 2)
self.assertEqual(disk_cached_lst[2].x, 3)
# check the cache is created
cache_dir = disk_cached_lst.get_cache_dir()
self.assertTrue(os.path.isdir(cache_dir))
# check the cache is properly released
del disk_cached_lst
self.assertFalse(os.path.isdir(cache_dir))
def test_disk_cached_dataloader(self):
"""Test the data loader backed by disk cache"""
height = 6
width = 8
runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
cfg = runner.get_default_cfg()
def _test_data_loader(data_loader):
first_batch = next(iter(data_loader))
self.assertTrue(first_batch[0]["height"], height)
self.assertTrue(first_batch[0]["width"], width)
# enable the disk cache
cfg.merge_from_list(["D2GO_DATA.DATASETS.DISK_CACHE.ENABLED", "True"])
with enable_disk_cached_dataset(cfg):
# no cache dir in the beginning
self.assertEqual(self._count_cache_dirs(), 0)
with create_fake_detection_data_loader(
height, width, is_train=True
) as train_loader:
# train loader should create one cache dir
self.assertEqual(self._count_cache_dirs(), 1)
_test_data_loader(train_loader)
with create_fake_detection_data_loader(
height, width, is_train=False
) as test_loader:
# test loader should create another cache dir
self.assertEqual(self._count_cache_dirs(), 2)
_test_data_loader(test_loader)
# test loader should release its cache
del test_loader
self.assertEqual(self._count_cache_dirs(), 1)
# no cache dir in the end
del train_loader
self.assertEqual(self._count_cache_dirs(), 0)