Commit 33ca49ac authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

better cleanup for disk_cache

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/267

https://www.internalfb.com/intern/test/281475036363610?ref_report_id=0 is flaky because multiple tests can run at the same time, and cleanup is not handled well in that case.

Reviewed By: tglik

Differential Revision: D36787035

fbshipit-source-id: 6a478318fe011af936dd10fa564519c8c0615ed3
parent 057dc5d2
......@@ -14,6 +14,16 @@ from detectron2.utils.logger import log_every_n_seconds
logger = logging.getLogger(__name__)
# NOTE: Use unique ROOT_CACHE_DIR for each run, during the run, each instance of data
# loader will create a `cache_dir` under ROOT_CACHE_DIR. When the DL instance is GC-ed,
# the `cache_dir` will be removed by __del__; when the run is finished or interrupted,
# atexit.register will be triggered to remove the ROOT_CACHE_DIR to make sure there's no
# leftovers. Regarding DDP: although each GPU process has its own random value for
# ROOT_CACHE_DIR, each GPU process uses the same `cache_dir` broadcast from the local
# master rank, which is then inherited by each data loader worker. This makes sure that
# `cache_dir` is in sync between all GPUs and DL workers on the same node.
ROOT_CACHE_DIR = "/tmp/DatasetFromList_cache_" + uuid.uuid4().hex[:8]
def _local_master_gather(func, check_equal=False):
if comm.get_local_rank() == 0:
......@@ -37,9 +47,6 @@ class DiskCachedDatasetFromList(data.Dataset):
save RAM usage.
"""
CACHE_DIR = "/tmp/DatasetFromList_cache"
_OCCUPIED_CACHE_DIRS = set()
def __init__(self, lst, strategy="batched_static"):
"""
Args:
......@@ -70,16 +77,13 @@ class DiskCachedDatasetFromList(data.Dataset):
def _initialize_diskcache(self):
from mobile_cv.common.misc.local_cache import LocalCache
cache_dir = "{}/{}".format(
DiskCachedDatasetFromList.CACHE_DIR, uuid.uuid4().hex[:8]
)
cache_dir = "{}/{}".format(ROOT_CACHE_DIR, uuid.uuid4().hex[:8])
cache_dir = comm.all_gather(cache_dir)[0] # use same cache_dir
logger.info("Creating diskcache database in: {}".format(cache_dir))
self._cache = LocalCache(cache_dir=cache_dir, num_shards=8)
# self._cache.cache.clear(retry=True) # seems faster if index exists
if comm.get_local_rank() == 0:
DiskCachedDatasetFromList.get_all_cache_dirs().add(self._cache.cache_dir)
if self._diskcache_strategy == "naive":
for i, item in enumerate(self._lst):
......@@ -183,40 +187,34 @@ class DiskCachedDatasetFromList(data.Dataset):
else:
raise NotImplementedError()
@classmethod
def get_all_cache_dirs(cls):
"""return all the ocupied cache dirs of DiskCachedDatasetFromList"""
return DiskCachedDatasetFromList._OCCUPIED_CACHE_DIRS
def get_cache_dir(self):
@property
def cache_dir(self):
"""return the current cache dirs of DiskCachedDatasetFromList instance"""
return self._cache.cache_dir
@staticmethod
def _clean_up_cache_dir(cache_dir, **kwargs):
print("Cleaning up cache dir: {}".format(cache_dir))
shutil.rmtree(
cache_dir,
onerror=lambda func, path, ex: print(
"Catch error when removing {}; func: {}; exc_info: {}".format(
path, func, ex
)
),
)
@staticmethod
@atexit.register
def _clean_up_root_cache_dir():
    # In case the program exits unexpectedly, clean up all the cache dirs
    # created by this session. Registered at import time via atexit, so it
    # runs once per process; only the local master rank removes the shared
    # ROOT_CACHE_DIR to avoid racing deletions across GPU processes.
    if comm.get_local_rank() == 0:
        _clean_up_cache_dir(ROOT_CACHE_DIR)

def __del__(self):
    # When a data loader is GC-ed, remove its cache dir. This is needed to
    # avoid wasting disk space when multiple data loaders are used, e.g.
    # running evaluations on multiple datasets during training. As with the
    # atexit hook, only the local master rank performs the removal.
    if comm.get_local_rank() == 0:
        _clean_up_cache_dir(self.cache_dir)
def _clean_up_cache_dir(cache_dir):
    """Best-effort recursive removal of `cache_dir`.

    Never raises: removal errors are reported and swallowed, since this is
    called from `__del__` and an `atexit` hook, where an exception would be
    unhelpful. Uses `print` rather than `logger` — presumably because the
    logging machinery may already be torn down at interpreter shutdown;
    TODO confirm.
    """
    print("Cleaning up cache dir: {}".format(cache_dir))
    shutil.rmtree(
        cache_dir,
        # Report-and-continue on any per-path failure (e.g. a file another
        # process already removed), so cleanup itself never crashes.
        onerror=lambda func, path, ex: print(
            "Catch error when removing {}; func: {}; exc_info: {}".format(
                path, func, ex
            )
        ),
    )
......@@ -4,10 +4,11 @@
import os
import shutil
import tempfile
import unittest
import torch
from d2go.data.disk_cache import DiskCachedDatasetFromList
from d2go.data.disk_cache import DiskCachedDatasetFromList, ROOT_CACHE_DIR
from d2go.data.utils import enable_disk_cached_dataset
from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
......@@ -22,11 +23,16 @@ class TestD2GoDatasetMapper(unittest.TestCase):
data loader in GeneralizedRCNNRunner (the default runner) in Detectron2Go.
"""
def setUp(self):
    # Give each test its own scratch output directory; addCleanup removes it
    # even if the test body fails or errors.
    self.output_dir = tempfile.mkdtemp(prefix="TestD2GoDatasetMapper_")
    self.addCleanup(shutil.rmtree, self.output_dir)
def test_default_dataset(self):
runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
cfg = runner.get_default_cfg()
cfg.DATASETS.TRAIN = ["default_dataset_train"]
cfg.DATASETS.TEST = ["default_dataset_test"]
cfg.OUTPUT_DIR = self.output_dir
with register_toy_coco_dataset("default_dataset_train", num_images=3):
train_loader = runner.build_detection_train_loader(cfg)
......@@ -56,15 +62,18 @@ class _MyClass(object):
class TestDiskCachedDataLoader(unittest.TestCase):
def setUp(self):
    """Start each test with an empty ROOT_CACHE_DIR and a private output dir."""
    # Make sure the ROOT_CACHE_DIR is empty when entering the test, so
    # cache-dir counting assertions are not polluted by earlier runs.
    if os.path.exists(ROOT_CACHE_DIR):
        shutil.rmtree(ROOT_CACHE_DIR)
    self.output_dir = tempfile.mkdtemp(prefix="TestDiskCachedDataLoader_")
    self.addCleanup(shutil.rmtree, self.output_dir)
def _count_cache_dirs(self):
    """Return the number of per-instance cache dirs under ROOT_CACHE_DIR.

    ROOT_CACHE_DIR may not exist at all (nothing cached yet, or already
    cleaned up), which counts as zero.
    """
    if not os.path.exists(ROOT_CACHE_DIR):
        return 0
    return len(os.listdir(ROOT_CACHE_DIR))
def test_disk_cached_dataset_from_list(self):
"""Test the class of DiskCachedDatasetFromList"""
......@@ -77,7 +86,7 @@ class TestDiskCachedDataLoader(unittest.TestCase):
self.assertEqual(disk_cached_lst[2].x, 3)
# check the cache is created
cache_dir = disk_cached_lst.get_cache_dir()
cache_dir = disk_cached_lst.cache_dir
self.assertTrue(os.path.isdir(cache_dir))
# check the cache is properly released
......@@ -90,6 +99,8 @@ class TestDiskCachedDataLoader(unittest.TestCase):
width = 8
runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
cfg = runner.get_default_cfg()
cfg.OUTPUT_DIR = self.output_dir
cfg.DATALOADER.NUM_WORKERS = 2
def _test_data_loader(data_loader):
first_batch = next(iter(data_loader))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment