Unverified Commit e71d2570 authored by Min Xu, committed by GitHub

[cleanup] remove ssd offload to simplify the FSDP code (#1080)



* simplified the readme

* clean up ssd offload

* try to fix readthedocs
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent f4fcee7e
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

# We need python > 3.8 due to a dependency on numpy.
build:
  os: ubuntu-20.04
  tools:
    python: "3.9"
    # You can also specify other tool versions:
    # nodejs: "16"
    # rust: "1.55"
    # golang: "1.17"

# Build documentation in the docs/ directory with Sphinx
sphinx:
  configuration: docs/source/conf.py

# If using Sphinx, optionally build your docs in additional formats such as PDF
# formats:
#   - pdf

# Optionally declare the Python requirements required to build your docs
python:
  install:
    - requirements: docs/requirements.txt
@@ -25,23 +25,6 @@ FairScale was designed with the following values in mind:

[![Explain Like I’m 5: FairScale](https://img.youtube.com/vi/oDt7ebOwWIc/0.jpg)](https://www.youtube.com/watch?v=oDt7ebOwWIc)
## What's New:
* March 2022 [fairscale 0.4.6 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.6).
* We have support for CosFace's LMCL in MEVO. This is a loss function that is suitable for a large number of prediction target classes.
* January 2022 [fairscale 0.4.5 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.5).
* We have experimental support for layer-wise gradient scaling.
* We enabled reduce_scatter operation overlapping in FSDP backward propagation.
* December 2021 [fairscale 0.4.4 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.4).
* FairScale is tested with the following PyTorch versions (with CUDA 11.2): 1.8.1, 1.10.0 and 1.11.0.dev20211101+cu111.
* November 2021 [fairscale 0.4.3 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.3).
* We have experimental support for offloading params to disk when using the FSDP API for evaluation workloads.
* We have an experimental layer that fuses multiple layers together to support large vocab size trainings.
* November 2021 [fairscale 0.4.2 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.2).
* We have a new experimental API called the LayerwiseMemoryTracker to help track, visualize and suggest fixes for memory issues occurring during the forward/backward pass of your models.
* Introducing SlowMoDistributedDataParallel API, a distributed training wrapper that is useful on clusters with slow network interconnects (e.g. Ethernet).
* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
## Installation

To install FairScale, please see the following [instructions](https://github.com/facebookresearch/fairscale/blob/main/docs/source/installation_instructions.rst).

@@ -50,134 +33,26 @@ You should be able to install a package with pip or conda, or build directly fro

## Getting Started

The full [documentation](https://fairscale.readthedocs.io/) contains instructions for getting started, deep dives and tutorials about the various FairScale APIs.

## Examples
Here are a few sample snippets from a subset of FairScale offerings:
### Pipe
Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.
```python
import torch
import fairscale
model = torch.nn.Sequential(a, b, c, d)
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
```
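The wrapped model is then used like any other `nn.Module`. Continuing from the snippet above, a minimal illustrative forward/backward pass (assuming `a` through `d` are, say, `torch.nn.Linear(10, 10)` layers) could look like this:
```python
x = torch.randn(32, 10, device="cuda:0")  # inputs go to the first pipeline device
y = model(x)                              # the output comes back on the last device (cuda:1)
y.sum().backward()                        # gradients flow back through both pipeline stages
```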
### Optimizer state sharding (ZeRO)
See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/main/benchmarks/oss.py), but a minimal example could look like the following:
```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

def train(
    rank: int,
    world_size: int,
    epochs: int):

    # DDP init example
    dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size)

    # Problem statement
    model = myAwesomeModel().to(rank)
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()

    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS

    # Wrap the optimizer in its state sharding brethren
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
    model = ShardedDDP(model, optimizer)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for batch in dataloader:
            # Train
            model.zero_grad()
            outputs = model(batch["inputs"])
            loss = loss_fn(outputs, batch["label"])
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
    mp.spawn(
        train,
        args=(
            WORLD_SIZE,
            EPOCHS,
        ),
        nprocs=WORLD_SIZE,
        join=True,
    )
```
### AdaScale SGD
AdaScale can wrap an SGD optimizer for use in DDP (Distributed Data Parallel) training,
or in non-DDP training with gradient accumulation. The benefit is being able to re-use
the same LR schedule from a baseline batch size when the effective batch size is bigger.
Note that AdaScale does _not_ help increase per-GPU batch size.
```python
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR  # or your scheduler
from fairscale.optim import AdaScale

...
optim = AdaScale(SGD(model.parameters(), lr=0.1))
scheduler = LambdaLR(optim, ...)
...
# Note: the train loop should be with DDP or with gradient accumulation.
last_epoch = 0
step = 0
done = False
while not done:
    for sample in dataset:
        ...
        step += optim.gain()
        optim.step()
        epoch = step // len(dataset)
        if last_epoch != epoch:
            scheduler.step()
            last_epoch = epoch
        if epoch > max_epoch:
            done = True
```
The primary goal is to allow scaling to bigger batch sizes without losing model accuracy.
(However, training time might be longer compared to not using AdaScale.)

At a high level, we want ML researchers to:
* go parallel more easily (i.e. no need to find new learning rate schedules)
* not worry about losing accuracy
* potentially achieve higher GPU efficiency (fewer steps, less networking overhead, etc.)

## FSDP

FullyShardedDataParallel (FSDP) is the recommended method for scaling to large NN models.
This library has been [upstreamed to PyTorch](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
The version of FSDP here exists for historical reference as well as for experimenting with
new and crazy ideas in research on scaling techniques. Please see the following blog post for
[how to use FairScale FSDP and how it works](https://engineering.fb.com/2021/07/15/open-source/fsdp/).
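As a quick orientation, here is a minimal sketch of wrapping a model with FairScale's FSDP (not taken from the original docs; it assumes a `torch.distributed` process group is already initialized and that `MyModel` and `dataloader` are defined elsewhere):
```python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Shards parameters, gradients and optimizer state across data-parallel workers.
model = FSDP(MyModel().cuda())
# Build the optimizer after wrapping, so it sees the sharded parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model.train()
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch["inputs"].cuda()).sum()  # placeholder loss for illustration
    loss.backward()
    optimizer.step()
```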
## Testing

We use circleci to test FairScale with the following PyTorch versions (with CUDA 11.2):
* the latest stable release (e.g. 1.10.0)
* the latest LTS release (e.g. 1.8.1)
* a recent nightly release (e.g. 1.11.0.dev20211101+cu111)

Please create an [issue](https://github.com/facebookresearch/fairscale/issues) if you are having trouble with installation.

## Contributors

We welcome contributions! Please see the [CONTRIBUTING](CONTRIBUTING.md) instructions for how you can contribute to FairScale.
## License

@@ -198,22 +73,9 @@ If you use FairScale in your publication, please cite it by using the following

```BibTeX
@Misc{FairScale2021,
-  author =       {Mandeep Baines and Shruti Bhosale and Vittorio Caggiano and Naman Goyal and Siddharth Goyal and Myle Ott and Benjamin Lefaudeux and Vitaliy Liptchinsky and Mike Rabbat and Sam Sheiffer and Anjali Sridhar and Min Xu},
+  author =       {FairScale authors},
  title =        {FairScale: A general purpose modular PyTorch library for high performance and large scale training},
  howpublished = {\url{https://github.com/facebookresearch/fairscale}},
  year =         {2021}
}
```
## FAQ
1. If you experience an error indicating that a default branch does not exist, it is probably due to the latest update, which switched the default branch from "master" to "main".
```
error: pathspec 'non-existing-branch' did not match any file(s) known to git
```
Please run the following commands to update to the main branch.
```
git branch -m master main
git fetch origin
git branch -u origin/main main
git remote set-head origin -a
```
@@ -25,7 +25,6 @@ from torch.optim import Adam
from benchmarks.golden_configs.lm_wikitext2 import FSDP as lm_wikitext2
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import OffloadConfig

RPC_PORT = 29501
@@ -95,10 +94,7 @@ def get_lm_model(args, device, config):
    nhid = config["nhid"]
    ndecoder = config["num_decoder_layers"]

-    if args.ssd_offload:
-        return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder)
-    else:
-        return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
+    return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
def get_tensors_by_size_bucket():
@@ -200,7 +196,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
        if i > 0:
            total_tokens += source.numel()

-        if args.benchmark_eval or args.ssd_offload:
+        if args.benchmark_eval:
            input = source.cuda()
            target = target.cuda()
            output = model(input)
@@ -250,7 +246,6 @@ def get_number_of_words(data):

def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
    # TODO(anj): Uncomment and add a check for regression once we have a couple of runs.
    golden_config = get_golden_config(args.model_name, args)
    epoch = benchmark_config["epochs"]
    start_time = time.time()
@@ -358,8 +353,6 @@ def benchmark_fsdp(rank, args, world_size):
    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
    model = model_config["model"]
    config = {}
    if args.ssd_offload:
        config["offload_config"] = OffloadConfig(offload_type="ssd_offload")

    if args.full_fp16:
        config["compute_dtype"] = torch.float16
@@ -386,7 +379,6 @@ parser.add_argument("--max_batch", type=int, default=4, help="Max number of batc
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
    # TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
    "--model_name",
    default="lm",
    help="Language Model(LM) used to benchmark FSDP.",
@@ -394,7 +386,6 @@ parser.add_argument(
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--enable_auto_wrap", action="store_true", default=False, help="Use auto_wrap with FSDP")
parser.add_argument("--benchmark_eval", action="store_true", default=False, help="Benchmark evaluation workflow.")
parser.add_argument("--ssd_offload", action="store_true", default=False, help="Benchmark ssd_offload workflow.")
parser.add_argument("--full_fp16", action="store_true", default=False, help="Benchmark in full fp16 mode.")

if __name__ == "__main__":
...
@@ -30,7 +30,7 @@ sys.path.insert(0, os.path.abspath("../.."))
# -- Project information -----------------------------------------------------

project = "FairScale"
-copyright = "2020-2021, Facebook/Meta AI Research"
+copyright = "2020-2022, Facebook/Meta AI Research"
author = "Facebook/Meta AI Research"
# -- General configuration ---------------------------------------------------
@@ -68,7 +68,7 @@ autodoc_inherit_docstrings = False
autodoc_member_order = "bysource"

intersphinx_mapping = {
-    "python": ("https://docs.python.org/3.6", None),
+    "python": ("https://docs.python.org/3.8", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
}
...
@@ -9,7 +9,6 @@ import torch.distributed as dist

from .fully_sharded_data_parallel import (
    FullyShardedDataParallel,
    OffloadConfig,
    TrainingState,
    auto_wrap_bn,
    get_fsdp_instances,
...
@@ -5,13 +5,11 @@

import contextlib
import copy
from dataclasses import dataclass
from enum import Enum, auto
import functools
import logging
from math import inf
import os
import tempfile
import time
import traceback
import typing
@@ -69,15 +67,6 @@ if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
else:
    enable_nccl_base_collectives = True
try:
    import fairscale.experimental.nn.ssd_offload as ssd_offload

    import_ssd_offload = True
except ImportError:
    # The latest nightly PyTorch version required
    import_ssd_offload = False
    pass
class TrainingState(Enum):
    """
@@ -107,19 +96,6 @@ class TrainingState(Enum):
    SUMMON_FULL_PARAMS = auto()
# Data classes containing FSDP parameter constructs
# Offload config for specifying SSD options (initially at least)
@dataclass
class OffloadConfig:
    """Class for specifying all arguments related to offloading parameters."""

    # Offload type: currently only supports: "ssd_offload"
    offload_type: Optional[str] = None

    # Path to the directory for storing parameters offloaded to disk.
    dir: Optional[str] = None
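
# A minimal sketch of how the (now removed) OffloadConfig was passed to FSDP before this
# cleanup, based on the benchmark change in this commit. Illustrative only: `my_module`
# and the directory are placeholders, a torch.distributed process group is assumed to be
# initialized already, and ssd_offload also required flatten_parameters=True.
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP, OffloadConfig

fsdp_model = FSDP(
    my_module,  # placeholder for the user's nn.Module
    flatten_parameters=True,
    offload_config=OffloadConfig(offload_type="ssd_offload", dir="/tmp/ssd_params"),
)
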
class FullyShardedDataParallel(nn.Module):
    """
    A wrapper for sharding Module parameters across data parallel workers. This
@@ -302,10 +278,6 @@ class FullyShardedDataParallel(nn.Module):
        cpu_offload (bool, Optional):
            if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of
            *``move_params_to_cpu``* in an upcoming release.
        offload_config (OffloadConfig):
            The `OffloadConfig` object is used to specify the type of offload (i.e SSD, CPU) and
            other required knobs when offloading parameters from GPU. Currently the OffloadConfig
            only supports specifying SSD offload as an option. Note: This is an experimental feature.
        state_dict_on_rank_0_only (bool):
            When set to ``True``, ``model.state_dict()`` will only returns full state dict on
            rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
@@ -342,7 +314,6 @@ class FullyShardedDataParallel(nn.Module):
        force_input_to_fp32: bool = False,
        verbose: bool = False,
        cpu_offload: bool = False,
        offload_config: Optional[OffloadConfig] = None,
        state_dict_on_rank_0_only: bool = False,
        gradient_predivide_factor: Optional[float] = None,
        allow_reset_parameters: bool = False,
@@ -414,12 +385,6 @@ class FullyShardedDataParallel(nn.Module):
        self.force_input_to_fp32 = force_input_to_fp32
        self.verbose = verbose
        self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
        # Experimental feature for now. Use at your own risk.
        self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
        if self.ssd_offload and not import_ssd_offload:
            raise ImportError(
                f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
            )
        self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
            self.world_size
@@ -433,9 +398,6 @@
        if self.fp32_reduce_scatter and not self.mixed_precision:
            raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
        if self.ssd_offload and not self.flatten_parameters:
            raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")
        # skip validation if the process group was created above
        if process_group:
            validate_process_group(self.compute_device, self.process_group)
@@ -456,16 +418,7 @@
        self._has_params = len(params) > 0
        self._has_shared_params = False
        # TODO(anj): Should we conditionally do this only if we have params?
        # TODO(anj): Figure out if we can allocate the buffer during sharding.
        self.buffer_size = sum(p.numel() for p in params)
        self.ssd_directory = tempfile.gettempdir()
        if self.ssd_offload:
            assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
            if offload_config and offload_config.dir:
                self.ssd_directory = offload_config.dir
            self.move_grads_to_cpu = True
            self.move_params_to_cpu = True
        # For now, it is either all flatten or none flatten. This will be extended to
        # multiple flatten groups in my next PR.
@@ -478,9 +431,7 @@
            param_name_groups = [param_names]
        del param_names

-        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
-            module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
-        )
+        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
        del module  # free original module in case it helps garbage collection

        # Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
@@ -531,8 +482,6 @@
        # Flag to indicate whether state_dict() should automatically summon the
        # full params. This defaults to True, but may be set to False if the
        # user explicitly requests the local state dict via local_state_dict().
        # TODO(anj): This should by default be set to False for ssd_offload=True
        # unless we are in the summon_full_params context.
        self._return_full_state_dict = True

        init_end = time.time()
@@ -544,11 +493,6 @@
        # This is reset at the end of the backward pass.
        self._pre_backward_hook_has_run = False

        # Free all params at the end of initialization.
        if self.ssd_offload:
            for m in get_fsdp_instances(self):
                m._free_ssd_offload()
    def _get_gradient_predivide_factor(self, world_size: int) -> float:
        factor: int = 1
        while world_size % factor == 0 and world_size / factor > factor:
@@ -785,10 +729,9 @@
            p._orig_size = p.data.size()

            if not p._is_sharded:
-                if not self.ssd_offload:
-                    p._is_sharded = False
-                    self.numel_padded_per_param.append(0)
-                    continue
+                p._is_sharded = False
+                self.numel_padded_per_param.append(0)
+                continue
            p._is_sharded = True

            # TODO (Min): broadcast from rank 0 to avoid each rank need to init with the same seed?
@@ -797,11 +740,7 @@
            p.data, num_padded = self._get_shard(p.data)
            self.numel_padded_per_param.append(num_padded)
-            if self.ssd_offload:
-                assert isinstance(p, ssd_offload.SsdParameter)
-                p.to_file()
-            else:
-                free_storage_(orig_data)
+            free_storage_(orig_data)

        assert len(self.numel_padded_per_param) == len(self.params)
@@ -1014,21 +953,11 @@
        backup = self._return_full_state_dict
        self._return_full_state_dict = False
        if self.ssd_offload:
            # Move params from disk to memory before returning the local state dict.
            self._move_params_to_memory()
        try:
            yield
        finally:
            self._return_full_state_dict = backup
    def _move_params_to_memory(self) -> None:
        """Move params from disk to CPU."""
        for p in self.params:
            assert isinstance(p, ssd_offload.SsdParameter)
            p.to_tensor()
    def _load_state_dict(
        self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
    ) -> NamedTuple:
@@ -1276,22 +1205,11 @@
                # Copy any changes made to the full params back into
                # the corresponding local shards.
                local_shard, _ = self._get_shard(full_tensor)
-                if self.ssd_offload:
-                    assert isinstance(p, ssd_offload.SsdParameter)
-                    self._ssd_offload_reset_param_device(p)
-                    p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu())
-                else:
-                    p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
+                p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
            if safe_to_free:
                free_storage_(full_tensor)

        self.has_full_params = False

-        if self.ssd_offload:
-            # Store tensors in the SSD buffer and free param storage.
-            for p in self.params:
-                assert isinstance(p, ssd_offload.SsdParameter)
-                p.to_file()
-        else:
-            self._use_fp32_param_shard()
+        self._use_fp32_param_shard()

        self.training_state = TrainingState.IDLE

    def _reset_lazy_init(self) -> None:
@@ -1366,11 +1284,6 @@
            return

        # A single shard of the parameters in full precision.
        # TODO(another-pjohnson) - I believe this will cause memory leakage with ssd
        # p.data returns a pointer to a handle, and that handle has it's
        # ref count incremented by p._fp32_shard. So this tensor will
        # never be freed even if we do p.to_disk(). investigate after
        # PR #887 is merged
        p._fp32_shard = p.data

        if self.mixed_precision:
@@ -1378,14 +1291,11 @@
        if self.move_params_to_cpu:
            assert p._fp32_shard.device == torch.device("cpu"), self
-            # We don't pin memory when using ssd_offload since that results in OOM when
-            # the memory requirements of a model are larger than host memory.
-            if not self.ssd_offload:
-                # If we plan to keep the FP32 parameters on CPU, then pinning
-                # memory allows us to later use non-blocking transfers when moving
-                # the FP32 param shard to compute_device.
-                p._fp32_shard = p._fp32_shard.pin_memory()
-                p.data = p._fp32_shard
+            # If we plan to keep the FP32 parameters on CPU, then pinning
+            # memory allows us to later use non-blocking transfers when moving
+            # the FP32 param shard to compute_device.
+            p._fp32_shard = p._fp32_shard.pin_memory()
+            p.data = p._fp32_shard
        if self.move_params_to_cpu or self.mixed_precision:
@@ -1423,16 +1333,7 @@
            # pass. In this case, it's important to pre-allocate the CPU grad
            # shard in pinned memory so that we can do a non-blocking transfer.
            # This is only needed during training and not evaluation.
-            if self.ssd_offload:
-                assert isinstance(p, ssd_offload.SsdParameter)
-                # Gradients also need to be offloaded to SSD otherwise it can result in
-                # OOMs when the memory requirements of a model are larger than host memory.
-                p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
-                p._cpu_grad.allow_unsafe_changes = True
-                p._cpu_grad.set_file_params(p.filename + "_grad", 0)
-                p._cpu_grad.to_file()
-            else:
-                p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
+            p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
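
# A small illustrative sketch (not part of this file) of why the pinned CPU buffer above
# matters: page-locked host memory lets `.to(..., non_blocking=True)` issue an asynchronous
# host-to-device copy that can overlap with compute. Assumes a CUDA device is available.
import torch

cpu_shard = torch.zeros(1 << 20).pin_memory()  # page-locked host memory
copy_stream = torch.cuda.Stream()
with torch.cuda.stream(copy_stream):
    gpu_shard = cpu_shard.to("cuda", non_blocking=True)  # async copy issued on copy_stream
torch.cuda.current_stream().wait_stream(copy_stream)  # sync before the default stream uses gpu_shard
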
    def _set_is_root(self) -> None:
        """If ``True``, implies that no other :class:`FullyShardedDataParallel`
@@ -1576,17 +1477,8 @@
        if self.clear_autocast_cache:
            torch.clear_autocast_cache()

        self._free_ssd_offload()
        return outputs
    @torch.no_grad()
    def _free_ssd_offload(self) -> None:
        if self.ssd_offload:
            for p in self.params:
                assert isinstance(p, ssd_offload.SsdParameter)
                p.to_file(permit_when_tensor_none=True)
    def _register_pre_backward_hooks(self, outputs: Any) -> Any:
        """Register pre-backward hook to run before the wrapped module's
        backward. Hooks should be attached to all outputs from the forward.
@@ -1990,7 +1882,6 @@
        # Update root and nested FSDP's hooks and flags.
        for m in get_fsdp_instances(self):
            _finalize_parameters(m)
            m._free_ssd_offload()
            m._pre_backward_hook_has_run = False
            if any(p.requires_grad for p in m.parameters()):
                # Check if the module has params and if any of them has
@@ -2071,15 +1962,6 @@
            # Trim any padding and reshape to match original size.
            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

        if self.ssd_offload:
            for p in self.params:
                assert isinstance(p, ssd_offload.SsdParameter)
                if not p.is_available():
                    self._ssd_offload_reset_param_device(p)
                    p.to_tensor()
            self.has_full_params = False

        if self._has_shared_params:
            # self.has_full_params flag can be out of sync if a shared param is
            # sharded by another FSDP instance. An example is that in eval case
@@ -2366,25 +2248,13 @@
        return consolidated_weights

    @torch.no_grad()
    def _ssd_offload_reset_param_device(self, param: Parameter) -> None:
        assert isinstance(param, ssd_offload.SsdParameter)
        if param.device != torch.device("cpu"):
            param.data = param._fp32_shard
            param.tensor = None
    @torch.no_grad()
    def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
        """Use FP32 shard for a list of params."""
        if params is None:
            params = self.params
        for p in params:
-            if import_ssd_offload and self.ssd_offload:
-                assert isinstance(p, ssd_offload.SsdParameter)
-                self._ssd_offload_reset_param_device(p)
-                p.to_tensor()
-            else:
-                p.data = p._fp32_shard
+            p.data = p._fp32_shard
    @torch.no_grad()
    def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
@@ -2395,14 +2265,11 @@
            for p in params:
                assert p._fp16_shard is not None
                alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
-                if self.ssd_offload:
-                    p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True))
-                else:
-                    p._fp16_shard.copy_(
-                        # If move_params_to_cpu is True, this will be non-blocking
-                        # because _fp32_shard is pinned, otherwise it's a no-op.
-                        p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
-                    )
+                p._fp16_shard.copy_(
+                    # If move_params_to_cpu is True, this will be non-blocking
+                    # because _fp32_shard is pinned, otherwise it's a no-op.
+                    p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
+                )
                p.data = p._fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
...
@@ -8,7 +8,6 @@
from contextlib import contextmanager
from itertools import chain
import tempfile
import typing
from typing import (
    TYPE_CHECKING,
@@ -31,19 +30,6 @@ import torch
from torch import Tensor
import torch.nn as nn
try:
    from fairscale.experimental.nn.ssd_offload import (
        SsdFlatParameter,
        SsdFlatParameterView,
        SsdFlatParameterViewProperty,
        _register_property,
    )

    import_ssd_offload = True
except ImportError:
    import_ssd_offload = False
    pass
from fairscale.internal.state_dict import replace_by_prefix_

if TYPE_CHECKING:
@@ -169,15 +155,8 @@ class FlattenParamsWrapper(nn.Module):
        module: nn.Module,
        param_list: ParamGroups = None,
        flat_param_names: Optional[List[str]] = None,
        ssd_offload: bool = False,
        ssd_directory: str = "",
    ):
        super().__init__()
        if ssd_offload and not import_ssd_offload:
            raise ImportError(
                f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
            )
        self.ssd_offload = ssd_offload
        self._fpw_module = module
        self.is_flattened = False
@@ -239,14 +218,7 @@ class FlattenParamsWrapper(nn.Module):
        # Init all flat_params.
        for new_p_set in self._param_sets:
            params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set)
-            if ssd_offload:
-                assert ssd_directory != ""
-                (handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
-                flat_param = SsdFlatParameter.from_tensors(tensors=params)
-                flat_param.allow_unsafe_changes = True
-                flat_param.set_file_params(fname, 0)
-            else:
-                flat_param = FlatParameter(params, params[0].requires_grad)
+            flat_param = FlatParameter(params, params[0].requires_grad)
            flat_param._param_infos = param_infos
            flat_param._shared_param_infos = shared_param_infos
            self.flat_params.append(flat_param)
@@ -393,13 +365,8 @@ class FlattenParamsWrapper(nn.Module):
        ps = self.get_param_views()
        param_views = []
        for (_, m, n), p in zip(self._param_infos, ps):
-            if self.ssd_offload:
-                assert isinstance(p, SsdFlatParameterView)
-                _register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id))
-            else:
-                setattr(m, n, p)  # This will set as plain attr
-            param_views.append(p)
+            setattr(m, n, p)  # This will set as plain attr
+            param_views.append(p)

        # Save param views for easy access if anyone still wants to access
        # parameters of the module.
...
@@ -6,7 +6,6 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/experimental/nn/test_auto_shard.py
tests/experimental/optim/test_dynamic_loss_scaler.py
tests/experimental/nn/test_ssd_offload.py
tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
tests/nn/data_parallel/test_fsdp_shared_weights.py
tests/nn/data_parallel/test_fsdp_pre_backward_hook.py
@@ -50,5 +49,4 @@ tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/nn/data_parallel/test_fsdp_offload.py
tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Testing SsdFlatParameter and SsdTensorHandle modules.
"""
import filecmp
import functools
import os
import tempfile
import numpy as np
import pytest
import torch
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try:
    import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
    # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
    pytestmark = pytest.mark.skipif(True, reason=ie.msg)
    pass
def _init():
    torch.manual_seed(0)
    np.random.seed(0)
def test_write_read():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        ref_tensor = torch.rand(128, dtype=torch.float32)
        test_tensor = torch.zeros_like(ref_tensor)
        assert not torch.equal(ref_tensor, test_tensor)
        so.write(ref_tensor, f.name)
        so.read(test_tensor, f.name)
        assert torch.equal(ref_tensor, test_tensor)
def test_ssd_handle_dispatch_fwd():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn(128)
        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)

        # This should trigger the torch_dispatch code and write
        # back the results to the file
        ssd_handle.add_(1)
        plus1_tensor = orig_tensor.add(1)
        assert torch.equal(ssd_handle.to_tensor(), plus1_tensor)
def test_ssd_handle_dispatch_bwd():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4), requires_grad=True)
        orig_copy = orig_tensor.clone().detach().requires_grad_(True)
        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)

        y1 = ssd_handle + 1
        y2 = orig_copy + 1
        y1.sum().backward()
        y2.sum().backward()
        assert torch.equal(ssd_handle.grad, orig_copy.grad)
@pytest.mark.skip("broken at head")
def test_ssd_handle_dispatch_bwd_hook():
    _init()

    def post_backward_hook(name, grad):
        print(f"BACKWARD HOOK for tensor {name} CALLED")

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4), requires_grad=True)
        orig_copy = orig_tensor.clone().detach().requires_grad_(True)
        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        one = torch.ones(1, requires_grad=True).cuda()
        orig_copy = ssd_handle.data
        cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
        ssd_handle.data = cuda_copy
        ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle"))
        one.register_hook(functools.partial(post_backward_hook, "one"))

        y1 = ssd_handle + one
        y1.sum().backward()
def test_ssd_handle_train_simple():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4), requires_grad=True)

        with torch.no_grad():
            orig_copy = torch.empty_like(orig_tensor)
            orig_copy.copy_(orig_tensor)
            orig_copy.requires_grad = True

        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
        ssd_handle.flush_on_dirty = False
        ssd_handle.set_file_params(f.name, 0)
        ssd_handle.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)

        optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1)
        optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1)

        y1 = ssd_handle + 1
        optimizer_ssd.zero_grad()
        y1.sum().backward()
        assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN
        optimizer_ssd.step()
        assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY

        y2 = orig_copy + 1
        optimizer_orig.zero_grad()
        y2.sum().backward()
        optimizer_orig.step()

        assert torch.equal(ssd_handle.to_tensor(), orig_copy)
def test_torch_save_load_ssd_flat_param_on_disk():
    _init()
    orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
    checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
    checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")

    # TENSOR_SHAPE = (1024, 1024, 2048)
    # use smaller shape for unit tests
    TENSOR_SHAPE = (1024, 321)

    ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
    ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
    ssd_handle.set_file_params(orig_file.name, 0)
    ssd_handle.to_file()
    ref_tensors = []

    # after deleting ref_tensor, memory usage should be very low
    # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
    with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
        so.torch_saver.save(ssd_handle, checkpoint_file.name)
        # below line saves file to checkpoint_load_directory/orig_file.name
        # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
        # 1000x because that's how many elements the python unpickler
        # will buffer before passing to the SsdTensor
        test_ssd_handle = torch.load(checkpoint_file)

    head, tail = os.path.split(orig_file.name)
    assert filecmp.cmp(orig_file.name, os.path.join(checkpoint_load_directory.name, tail), shallow=False)
def test_torch_save_load_ssd_flat_param_on_mem():
    _init()
    orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
    checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
    checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")

    # TENSOR_SHAPE = (1024, 1024, 2048)
    # use smaller shape for unit tests
    TENSOR_SHAPE = (1024, 321)

    ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
    ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
    ssd_handle.set_file_params(orig_file.name, 0)
    ref_tensors = []

    # after deleting ref_tensor, memory usage should be very low
    # For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
    with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
        so.torch_saver.save(ssd_handle, checkpoint_file.name)
        # below line saves file to checkpoint_load_directory/orig_file.name
        # Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
        # 1000x because that's how many elements the python unpickler
        # will buffer before passing to the SsdTensor
        test_ssd_handle = torch.load(checkpoint_file)

    assert torch.equal(ssd_handle, test_ssd_handle)
def test_ssd_param_train_simple():
    _init()

    with tempfile.NamedTemporaryFile() as f:
        orig_tensor = torch.randn((4, 4))

        with torch.no_grad():
            orig_copy = torch.empty_like(orig_tensor)
            orig_copy.copy_(orig_tensor)
            param = torch.nn.Parameter(orig_copy)

        ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
        ssd_param.point_to_tensor(orig_copy)
        ssd_param.flush_on_dirty = False
        ssd_param.set_file_params(f.name, 0)
        ssd_param.to_file(release_tensor_after_write=True)

        assert torch.equal(ssd_param.to_tensor(), orig_tensor)

        optimizer_ssd = torch.optim.SGD([ssd_param], lr=0.1)
        optimizer_orig = torch.optim.SGD([param], lr=0.1)

        y1 = ssd_param + 1
        optimizer_ssd.zero_grad()
        y1.sum().backward()
        # Test to see if Dirty is being calculated correctly when optimizer modifies
        # ssd_param
        assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN
        optimizer_ssd.step()
        assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY

        y2 = param + 1
        optimizer_orig.zero_grad()
        y2.sum().backward()
        optimizer_orig.step()

        assert torch.equal(ssd_param.to_tensor(), param)
def test_ssd_flat_parameter_basic():
    _init()
    with tempfile.NamedTemporaryFile() as f:
        refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
        ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
        ssd_flat_param.set_file_params(f.name, 0)

        param_views = list(ssd_flat_param.get_param_views())

        assert refa_param.shape == param_views[0].shape
        assert refb_param.shape == param_views[1].shape
        assert refc_param.shape == param_views[2].shape

        assert torch.equal(refa_param, param_views[0])
        assert torch.equal(refb_param, param_views[1])
        assert torch.equal(refc_param, param_views[2])

        ssd_flat_param.to_file()

        assert not ssd_flat_param.is_available()
        first_value = param_views[0][0][0].item()
        assert ssd_flat_param.is_available()
        assert first_value == refa_param[0][0].item()
def test_ssd_flat_parameter_view_modify():
    _init()
    with tempfile.NamedTemporaryFile() as f:
        refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
        refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
        refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=False)
        ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
        ssd_flat_param.set_file_params(f.name, 0)
        ssd_flat_param.flush_on_dirty = False

        param_views = list(ssd_flat_param.get_param_views())
        assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
        ssd_flat_param.to_file()
        assert ssd_flat_param.storage_state == so.StorageState.ON_DISK
        assert param_views[0].tensor is None

        param_views[0] += 0.1
        assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
@pytest.mark.skip("broken at head")
def test_ssd_flat_parameter_view_bwd():
    _init()
    hooks_called = []

    def post_backward_hook(name, hooks_called, *grads):
        print(f"BACKWARD HOOK for tensor {name} CALLED")
        hooks_called.append(name)

    with tempfile.NamedTemporaryFile() as f:
        refa_param = (
            torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
            .to("cpu")
            .detach()
            .requires_grad_()
        )
        refb_param = (
            torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
            .to("cpu")
            .detach()
            .requires_grad_()
        )
        refc_param = (
            torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=True)
            .to("cpu")
            .detach()
            .requires_grad_()
        )
        ssd_flat_param = so.SsdFlatParameter.from_tensors(
            [refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
        )

        orig_copy = ssd_flat_param.data
        cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
        cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()

        p_tmp = ssd_flat_param.expand_as(ssd_flat_param)  # Get a grad_fn on p_tmp.
        assert p_tmp.grad_fn is not None
        grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
        grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))

        ssd_flat_param.data = cuda_copy
        one = torch.ones(1, requires_grad=True, device=ssd_flat_param.device)
        y1 = ssd_flat_param.views[0] + one
        y2 = cuda_copy + 1

        # ssd_flat_param.to_file()
        # ssd_flat_param.data = orig_copy

        p_tmp = ssd_flat_param.expand_as(ssd_flat_param)  # Get a grad_fn on p_tmp.
        assert p_tmp.grad_fn is not None
        grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
        grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
        ssd_flat_param.views[0].register_hook(
            functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
        )
        ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
        one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))

        y1.sum().backward()
        y2.sum().backward()

        assert "GradAccumulation_cuda" in hooks_called
        assert "ssd_flat_param.views[0]" in hooks_called
        assert "ssd_flat_param" in hooks_called
        assert "one" in hooks_called
@pytest.mark.skip("broken at head")
def test_ssd_flat_parameter_view_bwd_parameterization():
    _init()
    hooks_called = []

    def post_backward_hook(name, hooks_called, *grads):
        print(f"BACKWARD HOOK for tensor {name} CALLED")
        hooks_called.append(name)

    with tempfile.NamedTemporaryFile() as f:
        layer1 = torch.nn.Linear(32, 4, bias=False)
        layer2 = torch.nn.Linear(32, 4, bias=False)
        layer3 = torch.nn.Linear(128, 1, bias=False)
        ssd_flat_param = so.SsdFlatParameter.from_tensors(
            [layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0
        )
        torch.nn.utils.parametrize.register_parametrization(
            layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0)
        )
        torch.nn.utils.parametrize.register_parametrization(
            layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1)
        )
        torch.nn.utils.parametrize.register_parametrization(
            layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2)
        )

        orig_copy = ssd_flat_param.data
        cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
        cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()

        p_tmp = ssd_flat_param.expand_as(ssd_flat_param)  # Get a grad_fn on p_tmp.
        assert p_tmp.grad_fn is not None
        grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
        grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))

        ssd_flat_param.to_file(release_tensor_after_write=False)
        ssd_flat_param.data = cuda_copy
        one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device)
        y1 = layer1.forward(one)
        y2 = cuda_copy + 1

        # ssd_flat_param.to_file()
        # ssd_flat_param.data = orig_copy

        p_tmp = ssd_flat_param.expand_as(ssd_flat_param)  # Get a grad_fn on p_tmp.
        assert p_tmp.grad_fn is not None
        grad_acc = p_tmp.grad_fn.next_functions[0][0]  # Gets its GradAccumulation object.
        grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
        ssd_flat_param.views[0].register_hook(
            functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
        )
        ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
        one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))

        y1.sum().backward()
        y2.sum().backward()

        assert "GradAccumulation_cuda" in hooks_called
        assert "ssd_flat_param.views[0]" in hooks_called
        assert "ssd_flat_param" in hooks_called
        assert "one" in hooks_called
def test_ssd_flat_parameter_direct_to_file():
    _init()
    with tempfile.NamedTemporaryFile() as f:
        refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
        refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
        ssd_flat_param = so.SsdFlatParameter.from_tensors(
            [refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
        )

        param_views = list(ssd_flat_param.get_param_views())

        assert refa_param.shape == param_views[0].shape
        assert refb_param.shape == param_views[1].shape
        assert refc_param.shape == param_views[2].shape

        assert torch.equal(refa_param, param_views[0])
        assert torch.equal(refb_param, param_views[1])
        assert torch.equal(refc_param, param_views[2])

        ssd_flat_param.to_file()

        assert not ssd_flat_param.is_available()
        first_value = param_views[0][0][0].item()
        assert ssd_flat_param.is_available()
        assert first_value == refa_param[0][0].item()