Commit 01c351bc authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

use SharedList as offload backend of DatasetFromList by default

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/405

- Use the non-hacky way (added in D40818736, https://github.com/facebookresearch/detectron2/pull/4626) to customize offloaded backend for DatasetFromList.
- In `D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)Go`, switch to use `SharedList` (added in D40789062, https://github.com/facebookresearch/mobile-vision/pull/120) by default to save RAM and optionally use `DiskCachedList` to further save RAM.

Local benchmarking results (using a ~2.4 GiB dataset) using dev mode:
| RAM usage (RES, SHR) | No-dataset | Naive | NumpySerializedList | SharedList | DiskCachedList |
| -- | -- | -- | -- | -- | -- |
| Master GPU worker.         | 8.0g, 2.8g | 21.4g, 2.8g | 11.6g, 2.8g | 11.5g, 5.2g | -- |
| Non-master GPU worker  | 7.5g, 2.8g | 21.0g, 2.8g | 11.5g, 2.8g | 8.0g, 2.8g | -- |
| Per data loader worker     | 2.0g, 1.0g | 14.0g, 1.0g | 4.4g, 1.0g | 2.1g, 1.0g | -- |

- The memory usage (RES, SHR) is found from `top` command. `RES` is total memory used per process; `SHR` shows how much RAM can be shared inside `RES`.
- experiments are done using 2 GPU and 2 data loader workers per GPU, so there're 6 processes in total, the **numbers are per-process**.
- `No-dataset`: running the same job with tiny dataset (only 4.47 MiB after serialization), since RAM usage should be negligible, it shows the floor RAM usage.
- other experiments are running using a dataset of the size of **2413.57 MiB** after serialization.
  - `Naive`: vanilla version if we don't offload the dataset to other storage.
  - `NumpySerializedList`: this optimization was added a long time ago in D19896490. I recalled that the RAM was indeed shared for data loader worker, but seems that there was a regression. Now basically all the processes have a copy of data.
  - `SharedList`: is enabled in this diff. It shows that only the master GPU needs extra RAM. It's interesting that it uses 3.5GB RAM more than other rank, while the data itself is 2.4GB. I'm not so sure if it's overhead of the storage itself or the overhead caused by sharing it with other processes, since non-master GPU using `NumpySerializedList` also uses 11.5g of RAM, we probably don't need to worry too much about it.
  - `DiskCachedList`: didn't benchmark, should have no extra RAM usage.

Using the above number for a typical 8GPU, 4worker training, assuming the OS and other programs take 20-30GB RAM, the current training will use `11.6g * 8 + 4.4g * 8*4 = 233.6g` RAM, on the edge of causing OOM for a 256gb machine. This aligns with our experience that it supports ~2GB dataset. After the change, the training will use only `(11.5g * 7 + 8.0g) + 2.1g * 8*4 = 155.7g` RAM, which gives a much larger head room, we can thus train with much larger dataset (eg. 20GB) or use more DL workers (eg. 8 workers).

Reviewed By: sstsai-adl

Differential Revision: D40819959

fbshipit-source-id: fbdc9d2d1d440e14ae8496be65979a09f3ed3638
parent c6666d33
......@@ -8,7 +8,6 @@ import shutil
import uuid
import numpy as np
import torch.utils.data as data
from detectron2.utils import comm
from detectron2.utils.logger import log_every_n_seconds
......@@ -41,10 +40,9 @@ def _local_master_gather(func, check_equal=False):
return x_local_master
class DiskCachedDatasetFromList(data.Dataset):
class DiskCachedList(object):
"""
Wrap a list to a torch Dataset, the underlying storage is off-loaded to disk to
save RAM usage.
Wrap a list, the underlying storage is off-loaded to disk to save RAM usage.
"""
def __init__(self, lst, strategy="batched_static"):
......
......@@ -26,8 +26,10 @@ from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.build import (
get_detection_dataset_dicts as d2_get_detection_dataset_dicts,
)
from detectron2.data.common import set_default_dataset_from_list_serialize_method
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager
from mobile_cv.torch.utils_pytorch.shareables import SharedList
logger = logging.getLogger(__name__)
......@@ -433,49 +435,37 @@ def local_master_get_detection_dataset_dicts(*args, **kwargs):
@contextlib.contextmanager
def enable_disk_cached_dataset(cfg):
def configure_dataset_creation(cfg):
"""
Context manager for enabling disk cache datasets, this is a experimental feature.
- Replace D2's DatasetFromList with DiskCachedDatasetFromList, needs to patch all
call sites.
Context manager for configure settings used during dataset creating. It supports:
- offload the dataset to shared memory to reduce RAM usage.
- (experimental) offload the dataset to disk cache to further reduce RAM usage.
- Replace D2's get_detection_dataset_dicts with a local-master-only version.
"""
if not cfg.D2GO_DATA.DATASETS.DISK_CACHE.ENABLED:
yield
return
def _patched_dataset_from_list(lst, **kwargs):
from d2go.data.disk_cache import DiskCachedDatasetFromList
logger.info("Patch DatasetFromList with DiskCachedDatasetFromList")
return DiskCachedDatasetFromList(lst)
dataset_from_list_patch_locations = []
dataset_from_list_offload_method = SharedList # use SharedList by default
if cfg.D2GO_DATA.DATASETS.DISK_CACHE.ENABLED:
# delay the import to avoid atexit cleanup
from d2go.data.disk_cache import DiskCachedList
def _maybe_add_dataset_from_list_patch_location(module_name):
try:
__import__(module_name)
dataset_from_list_patch_locations.append(f"{module_name}.DatasetFromList")
except ImportError:
pass
dataset_from_list_offload_method = DiskCachedList
_maybe_add_dataset_from_list_patch_location("detectron2.data.build")
_maybe_add_dataset_from_list_patch_location("d2go.data.build")
_maybe_add_dataset_from_list_patch_location("d2go.data.build_fb")
_maybe_add_dataset_from_list_patch_location("d2go.data.build_oss")
load_dataset_from_local_master = cfg.D2GO_DATA.DATASETS.DISK_CACHE.ENABLED
with contextlib.ExitStack() as stack:
for ctx in [
ctx_managers = [
set_default_dataset_from_list_serialize_method(
dataset_from_list_offload_method
)
]
if load_dataset_from_local_master:
ctx_managers.append(
mock.patch(
"detectron2.data.build.get_detection_dataset_dicts",
side_effect=local_master_get_detection_dataset_dicts,
),
*[
mock.patch(m, side_effect=_patched_dataset_from_list)
for m in dataset_from_list_patch_locations
],
]:
)
)
for ctx in ctx_managers:
stack.enter_context(ctx)
yield
......@@ -18,7 +18,7 @@ from d2go.data.dataset_mappers import build_dataset_mapper
from d2go.data.datasets import inject_coco_datasets, register_dynamic_datasets
from d2go.data.transforms.build import build_transform_gen
from d2go.data.utils import (
enable_disk_cached_dataset,
configure_dataset_creation,
maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset,
)
......@@ -508,14 +508,14 @@ class Detectron2GoRunner(BaseRunner):
logger.info(
"Building detection test loader for dataset: {} ...".format(dataset_name)
)
with enable_disk_cached_dataset(cfg):
with configure_dataset_creation(cfg):
mapper = mapper or cls.get_mapper(cfg, is_train=False)
logger.info("Using dataset mapper:\n{}".format(mapper))
return d2_build_detection_test_loader(cfg, dataset_name, mapper=mapper)
@classmethod
def build_detection_train_loader(cls, cfg, *args, mapper=None, **kwargs):
with enable_disk_cached_dataset(cfg):
with configure_dataset_creation(cfg):
mapper = mapper or cls.get_mapper(cfg, is_train=True)
data_loader = build_d2go_train_loader(cfg, mapper)
return cls._attach_visualizer_to_data_loader(cfg, data_loader)
......
......@@ -8,8 +8,8 @@ import tempfile
import unittest
import torch
from d2go.data.disk_cache import DiskCachedDatasetFromList, ROOT_CACHE_DIR
from d2go.data.utils import enable_disk_cached_dataset
from d2go.data.disk_cache import DiskCachedList, ROOT_CACHE_DIR
from d2go.data.utils import configure_dataset_creation
from d2go.runner import create_runner
from d2go.utils.testing.data_loader_helper import (
create_detection_data_loader_on_toy_dataset,
......@@ -76,10 +76,10 @@ class TestDiskCachedDataLoader(unittest.TestCase):
return len(os.listdir(ROOT_CACHE_DIR))
def test_disk_cached_dataset_from_list(self):
"""Test the class of DiskCachedDatasetFromList"""
"""Test the class of DiskCachedList"""
# check the discache can handel different data types
lst = [1, torch.tensor(2), _MyClass(3)]
disk_cached_lst = DiskCachedDatasetFromList(lst)
disk_cached_lst = DiskCachedList(lst)
self.assertEqual(len(disk_cached_lst), 3)
self.assertEqual(disk_cached_lst[0], 1)
self.assertEqual(disk_cached_lst[1].item(), 2)
......@@ -109,7 +109,7 @@ class TestDiskCachedDataLoader(unittest.TestCase):
# enable the disk cache
cfg.merge_from_list(["D2GO_DATA.DATASETS.DISK_CACHE.ENABLED", "True"])
with enable_disk_cached_dataset(cfg):
with configure_dataset_creation(cfg):
# no cache dir in the beginning
self.assertEqual(self._count_cache_dirs(), 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment