Commit d3115faf authored by Yanghan Wang, committed by Facebook GitHub Bot

delay import for diskcache

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/185

The `DiskCachedDatasetFromList` class originally lived in `d2go/data/utils.py`, so it was declared by default. As a result, the clean-up call (https://fburl.com/code/cu7hswhx) was always registered, even when the feature was not enabled. This diff moves the class to a new file and delays the import, so the clean-up won't run unless the feature is actually used.
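For context on the fix: `atexit.register` runs as soon as the code that calls it is executed, so a module that registers a cleanup hook at class-declaration time schedules that hook merely by being imported. A minimal, self-contained sketch of the delayed-import pattern (hypothetical names, not the actual d2go code):

import atexit


def _clean_up_all():
    print("Cleaning up cache dirs ...")


def _load_disk_cache_module():
    # Stands in for importing d2go.data.disk_cache: declaring the class there
    # applies `@atexit.register`, i.e. the hook is scheduled at import time.
    atexit.register(_clean_up_all)


def enable_feature(enabled):
    if not enabled:
        return
    # Delayed import: the side effect above happens only when the feature is on.
    _load_disk_cache_module()


enable_feature(enabled=False)  # nothing gets scheduled at interpreter exit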

Reviewed By: tglik

Differential Revision: D34601363

fbshipit-source-id: 734bb9b2c7957d7437ad40c4bfe60a441ec2f23a
parent d369931a
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import atexit
import logging
import pickle
import shutil
import uuid
import numpy as np
import torch.utils.data as data
from detectron2.utils import comm
from detectron2.utils.logger import log_every_n_seconds
logger = logging.getLogger(__name__)
def _local_master_gather(func, check_equal=False):
    if comm.get_local_rank() == 0:
        x = func()
        assert x is not None
    else:
        x = None
    x_all = comm.all_gather(x)
    x_local_master = [x for x in x_all if x is not None]

    if check_equal:
        master = x_local_master[0]
        assert all(x == master for x in x_local_master), x_local_master

    return x_local_master
class DiskCachedDatasetFromList(data.Dataset):
    """
    Wrap a list into a torch Dataset; the underlying storage is off-loaded to disk
    to save RAM usage.
    """

    CACHE_DIR = "/tmp/DatasetFromList_cache"
    _OCCUPIED_CACHE_DIRS = set()

    def __init__(self, lst, strategy="batched_static"):
        """
        Args:
            lst (list): a list which contains elements to produce.
            strategy (str): strategy of using diskcache, supported strategies:
                - naive: save each item individually.
                - batched_static: group N items together, where N is calculated from
                    the average item size.
        """
        self._lst = lst
        self._diskcache_strategy = strategy
        def _serialize(data):
            buffer = pickle.dumps(data, protocol=-1)
            return np.frombuffer(buffer, dtype=np.uint8)

        logger.info(
            "Serializing {} elements to byte tensors and concatenating them all ...".format(
                len(self._lst)
            )
        )
        self._lst = [_serialize(x) for x in self._lst]
        # TODO: only enable DiskCachedDataset for large enough datasets
        logger.info(
            "Serialized dataset takes {:.2f} MiB".format(
                sum(x.nbytes for x in self._lst) / 1024 ** 2
            )
        )
        self._initialize_diskcache()
    def _initialize_diskcache(self):
        from mobile_cv.common.misc.local_cache import LocalCache

        cache_dir = "{}/{}".format(
            DiskCachedDatasetFromList.CACHE_DIR, uuid.uuid4().hex[:8]
        )
        cache_dir = comm.all_gather(cache_dir)[0]  # all ranks share the same cache_dir
        logger.info("Creating diskcache database in: {}".format(cache_dir))
        self._cache = LocalCache(cache_dir=cache_dir, num_shards=8)
        # self._cache.cache.clear(retry=True)  # seems faster if index exists

        if comm.get_local_rank() == 0:
            DiskCachedDatasetFromList.get_all_cache_dirs().add(self._cache.cache_dir)

            if self._diskcache_strategy == "naive":
                for i, item in enumerate(self._lst):
                    ret = self._write_to_local_db((i, item))
                    assert ret, "Error writing index {} to local db".format(i)
                    pct = 100.0 * i / len(self._lst)
                    self._log_progress(pct)

            # NOTE: each item might be small in size (hundreds of bytes); writing
            # millions of them can take a pretty long time (hours) because of
            # frequent disk access. One solution is grouping a batch of items
            # into a larger blob.
            elif self._diskcache_strategy == "batched_static":
                TARGET_BYTES = 50 * 1024
                # estimate the average item size from (up to) 1000 evenly spaced samples
                average_bytes = np.average(
                    [
                        self._lst[int(x)].size
                        for x in np.linspace(0, len(self._lst) - 1, 1000)
                    ]
                )
                self._chunk_size = max(1, int(TARGET_BYTES / average_bytes))
                logger.info(
                    "Average data size: {} bytes; target chunk data size {} KiB;"
                    " {} items per chunk; {} chunks in total".format(
                        average_bytes,
                        TARGET_BYTES / 1024,
                        self._chunk_size,
                        int(len(self._lst) / self._chunk_size),
                    )
                )
                for i in range(0, len(self._lst), self._chunk_size):
                    chunk = self._lst[i : i + self._chunk_size]
                    chunk_i = i // self._chunk_size
                    ret = self._write_to_local_db((chunk_i, chunk))
                    assert ret, "Error writing index {} to local db".format(chunk_i)
                    pct = 100.0 * i / len(self._lst)
                    self._log_progress(pct)

            # NOTE: instead of using a fixed chunk size, items can be grouped dynamically
            elif self._diskcache_strategy == "batched_dynamic":
                raise NotImplementedError()

            else:
                raise NotImplementedError(self._diskcache_strategy)

        comm.synchronize()
        logger.info(
            "Finished writing to local disk, db size: {:.2f} MiB".format(
                self._cache.cache.volume() / 1024 ** 2
            )
        )

        # Optional sync for some strategies
        if self._diskcache_strategy == "batched_static":
            # propagate the chunk size and make sure all local masters use the same value
            self._chunk_size = _local_master_gather(
                lambda: self._chunk_size, check_equal=True
            )[0]

        self._size = _local_master_gather(lambda: len(self._lst), check_equal=True)[0]
        # free the memory of self._lst
        del self._lst
    def _write_to_local_db(self, task):
        index, record = task
        db_path = str(index)
        suc = self._cache.cache.set(db_path, record, retry=True)
        return suc

    def _log_progress(self, percentage):
        log_every_n_seconds(
            logging.INFO,
            "({:.2f}%) Wrote {} elements to local disk cache, db size: {:.2f} MiB".format(
                percentage,
                len(self._cache.cache),
                self._cache.cache.volume() / 1024 ** 2,
            ),
            n=10,
        )
    def __len__(self):
        if self._diskcache_strategy == "batched_static":
            return self._size
        else:
            raise NotImplementedError()

    def __getitem__(self, idx):
        if self._diskcache_strategy == "naive":
            buf = memoryview(self._cache.cache[str(idx)])
            return pickle.loads(buf)
        elif self._diskcache_strategy == "batched_static":
            chunk_i, residual = divmod(idx, self._chunk_size)
            chunk = self._cache.cache[str(chunk_i)]
            buf = memoryview(chunk[residual])
            return pickle.loads(buf)
        else:
            raise NotImplementedError()
    @classmethod
    def get_all_cache_dirs(cls):
        """Return all the occupied cache dirs of DiskCachedDatasetFromList."""
        return DiskCachedDatasetFromList._OCCUPIED_CACHE_DIRS

    def get_cache_dir(self):
        """Return the cache dir of this DiskCachedDatasetFromList instance."""
        return self._cache.cache_dir

    @staticmethod
    def _clean_up_cache_dir(cache_dir, **kwargs):
        print("Cleaning up cache dir: {}".format(cache_dir))
        shutil.rmtree(
            cache_dir,
            onerror=lambda func, path, ex: print(
                "Caught error when removing {}; func: {}; exc_info: {}".format(
                    path, func, ex
                )
            ),
        )
    @staticmethod
    @atexit.register
    def _clean_up_all():
        # In case the program exits unexpectedly, clean all the cache dirs created
        # by this session.
        if comm.get_local_rank() == 0:
            for cache_dir in DiskCachedDatasetFromList.get_all_cache_dirs():
                DiskCachedDatasetFromList._clean_up_cache_dir(cache_dir)

    def __del__(self):
        # When the data loader is GC-ed, remove the cache dir. This avoids wasting
        # disk space in case multiple data loaders are used, e.g. when running
        # evaluations on multiple datasets during training.
        if comm.get_local_rank() == 0:
            DiskCachedDatasetFromList._clean_up_cache_dir(self._cache.cache_dir)
            DiskCachedDatasetFromList.get_all_cache_dirs().remove(self._cache.cache_dir)
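For reference, a minimal usage sketch of the class above (illustrative only: it assumes mobile_cv is installed and a single process, where detectron2's comm utilities fall back to a world size of 1; the `records` list is made up):

from d2go.data.disk_cache import DiskCachedDatasetFromList

records = [{"file_name": "img_{}.jpg".format(i), "image_id": i} for i in range(1000)]
dataset = DiskCachedDatasetFromList(records, strategy="batched_static")

print(len(dataset))             # 1000, gathered from the local master
print(dataset[3])               # deserialized from the on-disk chunk holding index 3
print(dataset.get_cache_dir())  # e.g. /tmp/DatasetFromList_cache/<8 hex chars>
# The cache dir is removed when the dataset is GC-ed, or at interpreter exit.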
@@ -7,11 +7,9 @@ import contextlib
import json
import logging
import os
import pickle
import re
import shutil
import tempfile
import uuid
from collections import defaultdict
from unittest import mock
@@ -30,7 +28,6 @@ from detectron2.data.build import (
)
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import log_every_n_seconds
logger = logging.getLogger(__name__)
@@ -387,212 +384,6 @@ def update_cfg_if_using_adhoc_dataset(cfg):
return cfg
class _FakeListObj(list):
    def __init__(self, size):
        self.size = size
@@ -635,6 +426,8 @@ def enable_disk_cached_dataset(cfg):
        return

    def _patched_dataset_from_list(lst, **kwargs):
        from d2go.data.disk_cache import DiskCachedDatasetFromList

        logger.info("Patch DatasetFromList with DiskCachedDatasetFromList")
        return DiskCachedDatasetFromList(lst)
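A minimal sketch of how such a patch can be applied (assuming `unittest.mock.patch`, which this file already imports, and a boolean flag in place of the real config; the actual `enable_disk_cached_dataset` reads `cfg` instead):

import contextlib
from unittest import mock


@contextlib.contextmanager
def _enable_disk_cached_dataset_sketch(enabled):
    # Hypothetical simplification: swap detectron2's DatasetFromList for the
    # disk-cached version only while the context is active and the feature is on.
    if not enabled:
        yield
        return

    def _patched_dataset_from_list(lst, **kwargs):
        # Deferred import keeps disk_cache's atexit hook unregistered until
        # a dataset is actually built with the feature enabled.
        from d2go.data.disk_cache import DiskCachedDatasetFromList

        return DiskCachedDatasetFromList(lst)

    with mock.patch(
        "detectron2.data.build.DatasetFromList",
        side_effect=_patched_dataset_from_list,
    ):
        yield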
...
@@ -30,6 +30,7 @@ requirements = [
    "pytorch-lightning @ git+https://github.com/PyTorchLightning/pytorch-lightning.git@86b177ebe",
    "opencv-python",
    "parameterized",
    "diskcache",  # TODO: move to mobile_cv
]
...
@@ -7,7 +7,8 @@ import shutil
import unittest
import torch
from d2go.data.utils import DiskCachedDatasetFromList, enable_disk_cached_dataset
from d2go.data.disk_cache import DiskCachedDatasetFromList
from d2go.data.utils import enable_disk_cached_dataset
from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
    create_fake_detection_data_loader,
...