Commit 33ca49ac authored by Yanghan Wang, committed by Facebook GitHub Bot

better cleanup for disk_cache

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/267

The test https://www.internalfb.com/intern/test/281475036363610?ref_report_id=0 is flaky. The root cause is that multiple tests can run at the same time while sharing a single fixed cache directory, and cleanup is not handled well in that case: one run can wipe out a cache directory that another run is still using.
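For illustration, here is a minimal sketch of the failure mode and the fix (the two paths are the actual constants from this diff; the two-run scenario is illustrative, not code from the diff):

```python
import shutil
import uuid

# old: a fixed path shared by every run on the machine
CACHE_DIR = "/tmp/DatasetFromList_cache"

# if run B finishes first and wipes the shared root while run A is still
# reading from it, run A's cache vanishes mid-test and the test flakes:
shutil.rmtree(CACHE_DIR, ignore_errors=True)

# new: a random per-run suffix keeps concurrent runs from colliding
ROOT_CACHE_DIR = "/tmp/DatasetFromList_cache_" + uuid.uuid4().hex[:8]
```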

Reviewed By: tglik

Differential Revision: D36787035

fbshipit-source-id: 6a478318fe011af936dd10fa564519c8c0615ed3
parent 057dc5d2
d2go/data/disk_cache.py
@@ -14,6 +14,16 @@ from detectron2.utils.logger import log_every_n_seconds

 logger = logging.getLogger(__name__)

+# NOTE: Use a unique ROOT_CACHE_DIR for each run. During the run, each instance of the
+# data loader will create a `cache_dir` under ROOT_CACHE_DIR. When the DL instance is
+# GC-ed, its `cache_dir` is removed by __del__; when the run finishes or is interrupted,
+# the atexit.register hook removes the whole ROOT_CACHE_DIR to make sure there are no
+# leftovers. Regarding DDP, although each GPU process has its own random value for
+# ROOT_CACHE_DIR, every GPU process uses the same `cache_dir` broadcast from the local
+# master rank, which is then inherited by each data loader worker; this makes sure that
+# `cache_dir` is in sync between all GPUs and DL workers on the same node.
+ROOT_CACHE_DIR = "/tmp/DatasetFromList_cache_" + uuid.uuid4().hex[:8]
+
+
 def _local_master_gather(func, check_equal=False):
     if comm.get_local_rank() == 0:
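The broadcast described in the NOTE relies on detectron2's `comm.all_gather`, which returns each rank's value as a list; keeping index 0 means every rank adopts the master rank's random directory. A minimal sketch of just that mechanism, with `shared_cache_dir` as a hypothetical helper name:

```python
import uuid

from detectron2.utils import comm

# each process draws its own random root; only the master's value ends up used
ROOT_CACHE_DIR = "/tmp/DatasetFromList_cache_" + uuid.uuid4().hex[:8]


def shared_cache_dir():
    # every rank proposes a random subdir under its own ROOT_CACHE_DIR...
    proposal = "{}/{}".format(ROOT_CACHE_DIR, uuid.uuid4().hex[:8])
    # ...then all ranks keep entry 0 of the gathered list, i.e. the master
    # rank's proposal, so the directory is in sync across ranks
    return comm.all_gather(proposal)[0]
```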
@@ -37,9 +47,6 @@ class DiskCachedDatasetFromList(data.Dataset):
     save RAM usage.
     """

-    CACHE_DIR = "/tmp/DatasetFromList_cache"
-    _OCCUPIED_CACHE_DIRS = set()
-
     def __init__(self, lst, strategy="batched_static"):
         """
         Args:
@@ -70,16 +77,13 @@ class DiskCachedDatasetFromList(data.Dataset):
     def _initialize_diskcache(self):
         from mobile_cv.common.misc.local_cache import LocalCache

-        cache_dir = "{}/{}".format(
-            DiskCachedDatasetFromList.CACHE_DIR, uuid.uuid4().hex[:8]
-        )
+        cache_dir = "{}/{}".format(ROOT_CACHE_DIR, uuid.uuid4().hex[:8])
         cache_dir = comm.all_gather(cache_dir)[0]  # use same cache_dir
         logger.info("Creating diskcache database in: {}".format(cache_dir))
         self._cache = LocalCache(cache_dir=cache_dir, num_shards=8)
         # self._cache.cache.clear(retry=True)  # seems faster if index exists

         if comm.get_local_rank() == 0:
-            DiskCachedDatasetFromList.get_all_cache_dirs().add(self._cache.cache_dir)
-
             if self._diskcache_strategy == "naive":
                 for i, item in enumerate(self._lst):
@@ -183,40 +187,34 @@ class DiskCachedDatasetFromList(data.Dataset):
         else:
             raise NotImplementedError()

-    @classmethod
-    def get_all_cache_dirs(cls):
-        """return all the occupied cache dirs of DiskCachedDatasetFromList"""
-        return DiskCachedDatasetFromList._OCCUPIED_CACHE_DIRS
-
-    def get_cache_dir(self):
+    @property
+    def cache_dir(self):
         """return the current cache dir of the DiskCachedDatasetFromList instance"""
         return self._cache.cache_dir

-    @staticmethod
-    def _clean_up_cache_dir(cache_dir, **kwargs):
-        print("Cleaning up cache dir: {}".format(cache_dir))
-        shutil.rmtree(
-            cache_dir,
-            onerror=lambda func, path, ex: print(
-                "Catch error when removing {}; func: {}; exc_info: {}".format(
-                    path, func, ex
-                )
-            ),
-        )
-
     @staticmethod
     @atexit.register
-    def _clean_up_all():
+    def _clean_up_root_cache_dir():
         # in case the program exits unexpectedly, clean up all the cache dirs
         # created by this session.
         if comm.get_local_rank() == 0:
-            for cache_dir in DiskCachedDatasetFromList.get_all_cache_dirs():
-                DiskCachedDatasetFromList._clean_up_cache_dir(cache_dir)
+            _clean_up_cache_dir(ROOT_CACHE_DIR)

     def __del__(self):
         # when a data loader is GC-ed, remove its cache dir. This is needed to not
         # waste disk space in case multiple data loaders are used, e.g. when running
         # evaluations on multiple datasets during training.
         if comm.get_local_rank() == 0:
-            DiskCachedDatasetFromList._clean_up_cache_dir(self._cache.cache_dir)
-            DiskCachedDatasetFromList.get_all_cache_dirs().remove(self._cache.cache_dir)
+            _clean_up_cache_dir(self.cache_dir)
+
+
+def _clean_up_cache_dir(cache_dir):
+    print("Cleaning up cache dir: {}".format(cache_dir))
+    shutil.rmtree(
+        cache_dir,
+        onerror=lambda func, path, ex: print(
+            "Caught error when removing {}; func: {}; exc_info: {}".format(
+                path, func, ex
+            )
+        ),
+    )
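One detail worth noting: `_clean_up_cache_dir` passes an `onerror` callback to `shutil.rmtree`, so failed removals are printed rather than raised; cleanup that runs from `__del__` or the atexit hook should never crash the program. A small self-contained sketch of that behavior (the helper name is hypothetical):

```python
import shutil


def _log_rm_error(func, path, ex):
    # same shape as the lambda in the diff: report the failure, don't raise
    print("Caught error when removing {}; func: {}; exc_info: {}".format(path, func, ex))


# removing a directory that doesn't exist prints the error instead of raising
shutil.rmtree("/tmp/this_dir_does_not_exist", onerror=_log_rm_error)
```

The remainder of the diff updates the corresponding unit tests.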
@@ -4,10 +4,11 @@
 import os
 import shutil
+import tempfile
 import unittest

 import torch
-from d2go.data.disk_cache import DiskCachedDatasetFromList
+from d2go.data.disk_cache import DiskCachedDatasetFromList, ROOT_CACHE_DIR
 from d2go.data.utils import enable_disk_cached_dataset
 from d2go.runner import create_runner
 from d2go.utils.testing.data_loader_helper import (
@@ -22,11 +23,16 @@ class TestD2GoDatasetMapper(unittest.TestCase):
     data loader in GeneralizedRCNNRunner (the default runner) in Detectron2Go.
     """

+    def setUp(self):
+        self.output_dir = tempfile.mkdtemp(prefix="TestD2GoDatasetMapper_")
+        self.addCleanup(shutil.rmtree, self.output_dir)
+
     def test_default_dataset(self):
         runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
         cfg = runner.get_default_cfg()
         cfg.DATASETS.TRAIN = ["default_dataset_train"]
         cfg.DATASETS.TEST = ["default_dataset_test"]
+        cfg.OUTPUT_DIR = self.output_dir

         with register_toy_coco_dataset("default_dataset_train", num_images=3):
             train_loader = runner.build_detection_train_loader(cfg)
@@ -56,15 +62,18 @@ class _MyClass(object):

 class TestDiskCachedDataLoader(unittest.TestCase):
     def setUp(self):
-        # make sure the CACHE_DIR is empty when entering the test
-        if os.path.exists(DiskCachedDatasetFromList.CACHE_DIR):
-            shutil.rmtree(DiskCachedDatasetFromList.CACHE_DIR)
+        # make sure the ROOT_CACHE_DIR is empty when entering the test
+        if os.path.exists(ROOT_CACHE_DIR):
+            shutil.rmtree(ROOT_CACHE_DIR)
+
+        self.output_dir = tempfile.mkdtemp(prefix="TestDiskCachedDataLoader_")
+        self.addCleanup(shutil.rmtree, self.output_dir)

     def _count_cache_dirs(self):
-        if not os.path.exists(DiskCachedDatasetFromList.CACHE_DIR):
+        if not os.path.exists(ROOT_CACHE_DIR):
             return 0
-        return len(os.listdir(DiskCachedDatasetFromList.CACHE_DIR))
+        return len(os.listdir(ROOT_CACHE_DIR))

     def test_disk_cached_dataset_from_list(self):
         """Test the class of DiskCachedDatasetFromList"""
@@ -77,7 +86,7 @@ class TestDiskCachedDataLoader(unittest.TestCase):
         self.assertEqual(disk_cached_lst[2].x, 3)

         # check the cache is created
-        cache_dir = disk_cached_lst.get_cache_dir()
+        cache_dir = disk_cached_lst.cache_dir
         self.assertTrue(os.path.isdir(cache_dir))

         # check the cache is properly released
@@ -90,6 +99,8 @@ class TestDiskCachedDataLoader(unittest.TestCase):
         width = 8
         runner = create_runner("d2go.runner.GeneralizedRCNNRunner")
         cfg = runner.get_default_cfg()
+        cfg.OUTPUT_DIR = self.output_dir
+        cfg.DATALOADER.NUM_WORKERS = 2

         def _test_data_loader(data_loader):
             first_batch = next(iter(data_loader))
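Taken together, the test exercises the new lifecycle: creating a `DiskCachedDatasetFromList` materializes one subdirectory under ROOT_CACHE_DIR, and dropping the instance removes it again. A hypothetical single-process usage sketch of that contract (whether plain ints are valid list items is an assumption made for brevity; the test itself uses `_MyClass` objects):

```python
import os

from d2go.data.disk_cache import DiskCachedDatasetFromList, ROOT_CACHE_DIR

dataset = DiskCachedDatasetFromList([1, 2, 3])
assert os.path.isdir(dataset.cache_dir)  # one cache dir created under ROOT_CACHE_DIR

del dataset  # __del__ removes the instance's cache dir
assert not os.path.exists(ROOT_CACHE_DIR) or not os.listdir(ROOT_CACHE_DIR)
```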