"git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "67370881e7c6cfc4c31ca6e0eaf21f25318f4ea8"
Unverified Commit ef194cd2 authored by anj-s, committed by GitHub

[feature] Add an OffloadConfig object to specify offloading params to disk. (#855)

* fixed lint issues

* remove unused print statements

* add changelog entry

* [skip ci] fix lint errors
parent 2bfa5a61
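
At a glance, this change replaces the old undocumented `ssd_offload` keyword argument with an explicit `OffloadConfig` object passed to `FullyShardedDataParallel`. A minimal usage sketch, not part of the diff itself: the model is a placeholder, `torch.distributed` must already be initialized, and SSD offload additionally requires the PyTorch version gate noted in the tests below.

```python
import tempfile

import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig

model = nn.Linear(16, 16)  # stand-in for a real model

with tempfile.TemporaryDirectory() as ssd_dir:
    # offload_type currently accepts only "ssd_offload"; if ssd_filepath_dir
    # is omitted, the wrapper falls back to tempfile.gettempdir().
    offload = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=ssd_dir)
    fsdp_model = FullyShardedDataParallel(model, offload_config=offload)
```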
@@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Sharded Grad Scaler works with cpu offload in mixed and full precision. [#831]
+- API for specifying SSD offload for params with FSDP. You can use an OffloadConfig to specify the type of offload
+  and the file path for storing params on SSD. Note: This is an experimental feature. [#855]
 ### Changed
 - Cleanup: Moving forward we would be testing all of our code with Python 3.9.7, CUDA 11.2 and the following three versions of PyTorch [#847]:
...
@@ -10,6 +10,7 @@ import tempfile
 import torch
 from torch.utils.data import DataLoader
+import torchtext
 from torchtext.data.utils import get_tokenizer
 from torchtext.utils import download_from_url, extract_archive
...
@@ -18,8 +18,6 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-import torchtext
-from torchtext.data.utils import get_tokenizer
 from fairscale.experimental.nn.ampnet_pipe import pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
@@ -27,6 +25,8 @@ from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import LazyModule
 from fairscale.optim import GradScaler
 from fairscale.utils.testing import dist_init, get_worker_map
+import torchtext
+from torchtext.data.utils import get_tokenizer
 try:
     from fairscale.optim import Adam  # type: ignore
...
@@ -5,7 +5,7 @@
 from typing import List
-from .fully_sharded_data_parallel import FullyShardedDataParallel, TrainingState, auto_wrap_bn
+from .fully_sharded_data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState, auto_wrap_bn
 from .sharded_ddp import ShardedDataParallel

 __all__: List[str] = []
@@ -5,12 +5,13 @@
 import contextlib
 import copy
+from dataclasses import dataclass
 from enum import Enum, auto
 import functools
 import logging
 from math import inf
 import os
-from random import randint
+import tempfile
 import time
 import traceback
 import typing
@@ -100,6 +101,19 @@ class TrainingState(Enum):
     SUMMON_FULL_PARAMS = auto()
+
+# Data classes containing FSDP parameter constructs.
+# Offload config for specifying SSD options (initially, at least).
+@dataclass
+class OffloadConfig:
+    """Class for specifying all arguments related to offloading parameters."""
+
+    # Offload type: currently only "ssd_offload" is supported.
+    offload_type: Optional[str] = None
+    # Path to the directory for storing parameters offloaded to disk.
+    ssd_filepath_dir: Optional[str] = None
+
 class FullyShardedDataParallel(nn.Module):
     """
     A wrapper for sharding Module parameters across data parallel workers. This
@@ -260,6 +274,10 @@ class FullyShardedDataParallel(nn.Module):
         cpu_offload (bool, Optional):
             if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of
             *``move_params_to_cpu``* in an upcoming release.
+        offload_config (OffloadConfig):
+            The ``OffloadConfig`` object is used to specify the type of offload (i.e. SSD, CPU) and
+            any other knobs required when offloading parameters from the GPU. Currently it
+            only supports specifying SSD offload as an option. Note: This is an experimental feature.
     """
     def __init__(
@@ -282,7 +300,7 @@ class FullyShardedDataParallel(nn.Module):
         force_input_to_fp32: bool = False,
         verbose: bool = False,
         cpu_offload: bool = False,
-        **kwargs: Dict[str, Any],
+        offload_config: Optional[OffloadConfig] = None,
     ):
         init_start = time.time()
         super().__init__()
@@ -306,7 +324,7 @@ class FullyShardedDataParallel(nn.Module):
         self.force_input_to_fp32 = force_input_to_fp32
         self.verbose = verbose
         # Experimental feature for now. Use at your own risk.
-        self.ssd_offload = kwargs.get("ssd_offload", False)
+        self.ssd_offload = offload_config is not None and offload_config.offload_type == "ssd_offload"
         self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
@@ -339,12 +357,13 @@ class FullyShardedDataParallel(nn.Module):
         # TODO(anj): Should we conditionally do this only if we have params?
         # TODO(anj): Figure out if we can allocate the buffer during sharding.
         self.buffer_size = sum(p.numel() for p in params)
-        self.ssd_buffer_filename = ""
         if self.ssd_offload:
             assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
-            # TODO(anj): Add support for temp file and directory as possible API params.
-            self.ssd_buffer_filename = f"{randint(1, int(10E6))}_rank{self.rank}"
-            self.ssd_buffer = ssd_offload.SsdBuffer(self.buffer_size, self.ssd_buffer_filename)
+            self.ssd_buffer_filepath_dir = (
+                offload_config.ssd_filepath_dir if offload_config.ssd_filepath_dir else tempfile.gettempdir()
+            )
+            self.ssd_buffer_filename = tempfile.mkstemp(dir=self.ssd_buffer_filepath_dir)
+            self.ssd_buffer = ssd_offload.SsdBuffer(self.buffer_size, self.ssd_buffer_filename[1])
             self.move_grads_to_cpu = True
             self.move_params_to_cpu = True
...
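
One detail worth flagging in the `__init__` hunk above: `tempfile.mkstemp` returns an `(fd, path)` tuple, not a bare path, which is why the constructor stores the whole tuple in `self.ssd_buffer_filename` and indexes `[1]` when handing the path to `SsdBuffer`. A standalone illustration of the standard-library behavior:

```python
import os
import tempfile

# mkstemp creates the file immediately and returns (file descriptor, path).
fd, path = tempfile.mkstemp(dir=tempfile.gettempdir())
print(path)  # e.g. /tmp/tmpa1b2c3

# The descriptor stays open and the file is not deleted automatically,
# so the caller owns cleanup.
os.close(fd)
os.unlink(path)
```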
@@ -27,4 +27,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchvision"]
@@ -4,10 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 import functools
-import glob
 import itertools
-import os
 import sys
+import tempfile
 import time
 import unittest
@@ -18,11 +17,12 @@ from torch import nn
 import torch.distributed
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
+from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
 from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, rmf, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
 # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
+print(f"torch version {torch_version()}")
 pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
@@ -32,8 +32,6 @@ pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires t
 class DistributedTest(unittest.TestCase):
     def setUp(self):
-        if torch_version() < (1, 6, 0):
-            raise unittest.SkipTest("Need pytorch version >= 1.6 due to lack of reduce_scatter")
         if not torch.cuda.is_available():
             raise unittest.SkipTest("CUDA not available, skipping test")
         if sys.platform == "win32":
@@ -102,8 +100,12 @@ class DistributedTest(unittest.TestCase):
             ref_state_dict[k] = ref_state_dict[k].cpu()
         # Confirm we get the same behavior using FullyShardedDataParallel.
+        if config.get("ssd_offload", False):
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
+            del config["ssd_offload"]
+
         model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
-        if not config.get("ssd_offload", False):
+        if not model.ssd_offload and not model.move_params_to_cpu:
             if use_cuda:
                 model = model.cuda()
             else:
@@ -149,17 +151,15 @@ class TestSsdMemory(DistributedTest):
         model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
         time_keeper.print_time("CPU_MODEL", 1.0)
-        config["ssd_offload"] = True
-        model = FullyShardedDataParallel(model, **config)
-        time_keeper.print_time("FSDP_MODEL", 1.0)
-        self._eval_for_several_steps(model, 1, autocast=False)
-        time_keeper.print_time("EVAL")
-        fileList = glob.glob(os.getcwd() + "/*_rank*")
-        for file in fileList:
-            rmf(file)
+        with tempfile.TemporaryDirectory() as current_tempdir:
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+            model = FullyShardedDataParallel(model, **config)
+            time_keeper.print_time("FSDP_MODEL", 1.0)
+            self._eval_for_several_steps(model, 1, autocast=False)
+            time_keeper.print_time("EVAL")

 class SimpleLinear(nn.Module):
     def __init__(self, group, input_size, output_size, layers=1, **unused_kwargs):
@@ -221,10 +221,14 @@ class TestModuleProperties(DistributedTest):
         before_wrap_model = TransformerWithSharedParams(group)
         before_wrap_params = before_wrap_model.named_parameters()
-        config["ssd_offload"] = True
-        model = FullyShardedDataParallel(before_wrap_model, **config)
-        if not config["ssd_offload"]:
-            model = model.cuda()
-        self._eval_with_config(model, autocast=config["mixed_precision"])
+        with tempfile.TemporaryDirectory() as current_tempdir:
+            if config["ssd_offload"]:
+                config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+            del config["ssd_offload"]
+            model = FullyShardedDataParallel(before_wrap_model, **config)
+            print(f"model.ssd_offload {model.ssd_offload}")
+            if not model.ssd_offload and not model.move_params_to_cpu:
+                model = model.cuda()
+            self._eval_with_config(model, autocast=config["mixed_precision"])
@@ -252,7 +256,7 @@ class TestSsdLoading(DistributedTest):
         test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
         spawn_and_init(test_fn)
-    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    @parameterized.expand(CONFIG, name_func=rename_test)
     def test_transformer_parameterized(self, config):
         spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@@ -264,14 +268,15 @@ class TestSsdLoading(DistributedTest):
         nested_wrapping = config["nested_wrapping"]
         del config["nested_wrapping"]
-        config["ssd_offload"] = True
-        if nested_wrapping:
-            model = FullyShardedDataParallel(NestedWrappedModule(group, wrap_everything=True, wrapper_config=config))
-        else:
-            model = FullyShardedDataParallel(model, **config)
-        if not config["ssd_offload"]:
-            model = model.cuda()
-        self._eval_with_config(model, autocast=config["mixed_precision"])
+        with tempfile.TemporaryDirectory() as current_tempdir:
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+            if nested_wrapping:
+                model = FullyShardedDataParallel(
+                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
+                )
+            else:
+                model = FullyShardedDataParallel(model, **config)
+            self._eval_with_config(model, autocast=config["mixed_precision"])
             # With SSD offload only local_state_dict will work. We can support global
@@ -281,10 +286,6 @@ class TestSsdLoading(DistributedTest):
             self._eval_with_config(model, config["mixed_precision"])
-            fileList = glob.glob(os.getcwd() + "/*_rank*")
-            for file in fileList:
-                rmf(file)

 class TransformerWithSharedParams(nn.Module):
     def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
...
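
The test updates all follow one pattern: a legacy `{"ssd_offload": True}` entry in a config dict is rewritten into the new `offload_config` keyword before the dict is splatted into the constructor, and the offload files now live in a `tempfile.TemporaryDirectory`, which removes the need for the old `glob`/`rmf` cleanup sweep. A condensed sketch of that translation (the config contents are illustrative, and the constructor call is elided because it needs an initialized process group):

```python
import tempfile

from fairscale.nn.data_parallel import OffloadConfig

config = {"mixed_precision": False, "ssd_offload": True}  # legacy-style test config

with tempfile.TemporaryDirectory() as tmpdir:
    if config.pop("ssd_offload", False):
        config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=tmpdir)
    # model = FullyShardedDataParallel(model, **config)  # dict is splatted as before

assert "ssd_offload" not in config
assert config["offload_config"].offload_type == "ssd_offload"
```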