Unverified Commit e71d2570 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[cleanup] remove ssd offload to simplify the FSDP code (#1080)



* simlificed the readme

* clean up ssd offload

* try to fix readthedocs
Co-authored-by: default avatarMin Xu <min.xu.public@gmail.com>
parent f4fcee7e
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# We need python > 3.8 due to a dependency on numpy.
build:
os: ubuntu-20.04
tools:
python: "3.9"
# You can also specify other tool versions:
# nodejs: "16"
# rust: "1.55"
# golang: "1.17"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
# If using Sphinx, optionally build your docs in additional formats such as PDF
# formats:
# - pdf
# Optionally declare the Python requirements required to build your docs
python:
install:
- requirements: docs/requirements.txt
......@@ -25,23 +25,6 @@ FairScale was designed with the following values in mind:
[![Explain Like I’m 5: FairScale](https://img.youtube.com/vi/oDt7ebOwWIc/0.jpg)](https://www.youtube.com/watch?v=oDt7ebOwWIc)
## What's New:
* March 2022 [fairscale 0.4.6 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.6).
* We have support for CosFace's LMCL in MEVO. This is a loss function that is suitable for large number of prediction target classes.
* January 2022 [fairscale 0.4.5 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.5).
* We have experimental support for layer wise gradient scaling.
* We enabled reduce_scatter operation overlapping in FSDP backward propagation.
* December 2021 [fairscale 0.4.4 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.4).
* FairScale is tested with the following PyTorch versions (with CUDA 11.2): 1.8.1, 1.10.0 and 1.11.0.dev20211101+cu111.
* November 2021 [fairscale 0.4.3 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.3).
* We have experimental support for offloading params to disk when using the FSDP API for evaluation workloads.
* We have an experimental layer that fuses multiple layers together to support large vocab size trainings.
* November 2021 [fairscale 0.4.2 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.2).
* We have a new experimental API called the LayerwiseMemoryTracker to help track, visualize and suggest fixes for memory issues occurring during the forward/backward pass of your models.
* Introducing SlowMoDistributedDataParallel API, a distributed training wrapper that is useful on clusters with slow network interconnects (e.g. Ethernet).
* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
## Installation
To install FairScale, please see the following [instructions](https://github.com/facebookresearch/fairscale/blob/main/docs/source/installation_instructions.rst).
......@@ -50,134 +33,26 @@ You should be able to install a package with pip or conda, or build directly fro
## Getting Started
The full [documentation](https://fairscale.readthedocs.io/) contains instructions for getting started, deep dives and tutorials about the various FairScale APIs.
## Examples
Here are a few sample snippets from a subset of FairScale offerings:
### Pipe
Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.
```python
import torch
import fairscale
model = torch.nn.Sequential(a, b, c, d)
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
```
### Optimizer state sharding (ZeRO)
See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/main/benchmarks/oss.py), but a minimal example could look like the following :
```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
def train(
rank: int,
world_size: int,
epochs: int):
# DDP init example
dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
# Problem statement
model = myAwesomeModel().to(rank)
dataloader = mySuperFastDataloader()
loss_fn = myVeryRelevantLoss()
base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
# Wrap the optimizer in its state sharding brethren
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model = ShardedDDP(model, optimizer)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
for batch in dataloader:
# Train
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
# Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
mp.spawn(
train,
args=(
WORLD_SIZE,
EPOCHS,
),
nprocs=WORLD_SIZE,
join=True,
)
```
### AdaScale SGD
AdaScale can be used to wrap a SGD optimizer and to be used in DDP (Distributed Data Parallel)
training or non-DDP with gradient accumulation. The benefit is to re-use the same LR
schedule from a baseline batch size when effective batch size is bigger.
Note that AdaScale does _not_ help increase per-GPU batch size.
```python
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR # or your scheduler
from fairscale.optim import AdaScale
...
optim = AdaScale(SGD(model.parameters(), lr=0.1))
scheduler = LambdaLR(optim, ...)
...
# Note: the train loop should be with DDP or with gradient accumulation.
last_epoch = 0
step = 0
done = False
while not done:
for sample in dataset:
...
step += optim.gain()
optim.step()
epoch = step // len(dataset)
if last_epoch != epoch:
scheduler.step()
last_epoch = epoch
if epoch > max_epoch:
done = True
```
## FSDP
Primary goal is to allow scaling to bigger batch sizes without losing model accuracy.
(However, training time might be longer comparing to without AdaScale.)
At a high level, we want ML researchers to:
* go parallel more easily (i.e. no need to find new learning rate schedules)
* not worrying about losing accuracy
* potentially higher GPU efficiency (fewer steps, less networking overhead, etc.)
FullyShardedDataParallel (FSDP) is the recommended method for scaling to large NN models.
This library has been [upstreamed to PyTorch](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
The version of FSDP here is for historical references as well as for experimenting with
new and crazy ideas in research of scaling techniques. Please see the following blog
for [how to use FairScale FSDP and how does it work](https://engineering.fb.com/2021/07/15/open-source/fsdp/).
## Testing
We use circleci to test FairScale with the following PyTorch versions (with CUDA 11.2):
* the latest stable release (1.10.0)
* the latest LTS release (1.8.1)
* a recent nightly release (1.11.0.dev20211101+cu111)
* the latest stable release (e.g. 1.10.0)
* the latest LTS release (e.g. 1.8.1)
* a recent nightly release (e.g. 1.11.0.dev20211101+cu111)
Please create an [issue](https://github.com/facebookresearch/fairscale/issues) if you are having trouble with installation.
## Contributors
We welcome outside contributions! Please see the [CONTRIBUTING](CONTRIBUTING.md) instructions for how you can contribute to FairScale.
We welcome contributions! Please see the [CONTRIBUTING](CONTRIBUTING.md) instructions for how you can contribute to FairScale.
## License
......@@ -198,22 +73,9 @@ If you use FairScale in your publication, please cite it by using the following
```BibTeX
@Misc{FairScale2021,
author = {Mandeep Baines and Shruti Bhosale and Vittorio Caggiano and Naman Goyal and Siddharth Goyal and Myle Ott and Benjamin Lefaudeux and Vitaliy Liptchinsky and Mike Rabbat and Sam Sheiffer and Anjali Sridhar and Min Xu},
author = {FairScale authors},
title = {FairScale: A general purpose modular PyTorch library for high performance and large scale training},
howpublished = {\url{https://github.com/facebookresearch/fairscale}},
year = {2021}
}
```
## FAQ
1. If you experience an error indicating a default branch does not exist, it probably due to the latest update, switching the default branch from "master" to "main"
```
error: pathspec 'non-existing-branch' did not match any file(s) known to git
```
Please run the following commands to update to the main branch.
```
git branch -m master main
git fetch origin
git branch -u origin/main main
git remote set-head origin -a
```
......@@ -25,7 +25,6 @@ from torch.optim import Adam
from benchmarks.golden_configs.lm_wikitext2 import FSDP as lm_wikitext2
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import OffloadConfig
RPC_PORT = 29501
......@@ -95,9 +94,6 @@ def get_lm_model(args, device, config):
nhid = config["nhid"]
ndecoder = config["num_decoder_layers"]
if args.ssd_offload:
return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder)
else:
return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
......@@ -200,7 +196,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
if i > 0:
total_tokens += source.numel()
if args.benchmark_eval or args.ssd_offload:
if args.benchmark_eval:
input = source.cuda()
target = target.cuda()
output = model(input)
......@@ -250,7 +246,6 @@ def get_number_of_words(data):
def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
# TODO(anj): Uncomment and add a check for regression once we have a couple of runs.
golden_config = get_golden_config(args.model_name, args)
epoch = benchmark_config["epochs"]
start_time = time.time()
......@@ -358,8 +353,6 @@ def benchmark_fsdp(rank, args, world_size):
model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
model = model_config["model"]
config = {}
if args.ssd_offload:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
if args.full_fp16:
config["compute_dtype"] = torch.float16
......@@ -386,7 +379,6 @@ parser.add_argument("--max_batch", type=int, default=4, help="Max number of batc
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
# TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
"--model_name",
default="lm",
help="Language Model(LM) used to benchmark FSDP.",
......@@ -394,7 +386,6 @@ parser.add_argument(
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--enable_auto_wrap", action="store_true", default=False, help="Use auto_wrap with FSDP")
parser.add_argument("--benchmark_eval", action="store_true", default=False, help="Benchmark evaluation workflow.")
parser.add_argument("--ssd_offload", action="store_true", default=False, help="Benchmark ssd_offload workflow.")
parser.add_argument("--full_fp16", action="store_true", default=False, help="Benchmark in full fp16 mode.")
if __name__ == "__main__":
......
......@@ -30,7 +30,7 @@ sys.path.insert(0, os.path.abspath("../.."))
# -- Project information -----------------------------------------------------
project = "FairScale"
copyright = "2020-2021, Facebook/Meta AI Research"
copyright = "2020-2022, Facebook/Meta AI Research"
author = "Facebook/Meta AI Research"
# -- General configuration ---------------------------------------------------
......@@ -68,7 +68,7 @@ autodoc_inherit_docstrings = False
autodoc_member_order = "bysource"
intersphinx_mapping = {
"python": ("https://docs.python.org/3.6", None),
"python": ("https://docs.python.org/3.8", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
}
......
This diff is collapsed.
......@@ -9,7 +9,6 @@ import torch.distributed as dist
from .fully_sharded_data_parallel import (
FullyShardedDataParallel,
OffloadConfig,
TrainingState,
auto_wrap_bn,
get_fsdp_instances,
......
......@@ -5,13 +5,11 @@
import contextlib
import copy
from dataclasses import dataclass
from enum import Enum, auto
import functools
import logging
from math import inf
import os
import tempfile
import time
import traceback
import typing
......@@ -69,15 +67,6 @@ if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
else:
enable_nccl_base_collectives = True
try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
import_ssd_offload = True
except ImportError:
# The latest nightly PyTorch version required
import_ssd_offload = False
pass
class TrainingState(Enum):
"""
......@@ -107,19 +96,6 @@ class TrainingState(Enum):
SUMMON_FULL_PARAMS = auto()
# Data classes containing FSDP parameter constructs
# Offload config for specifying SSD options (initially at least)
@dataclass
class OffloadConfig:
"""Class for specifying all arguments related to offloading parameters."""
# Offload type: currently only supports: "ssd_offload"
offload_type: Optional[str] = None
# Path to the directory for storing parameters offloaded to disk.
dir: Optional[str] = None
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
......@@ -302,10 +278,6 @@ class FullyShardedDataParallel(nn.Module):
cpu_offload (bool, Optional):
if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of
*``move_params_to_cpu``* in an upcoming release.
offload_config (OffloadConfig):
The `OffloadConfig` object is used to specify the type of offload (i.e SSD, CPU) and
other required knobs when offloading parameters from GPU. Currently the OffloadConfig
only supports specifying SSD offload as an option. Note: This is an experimental feature.
state_dict_on_rank_0_only (bool):
When set to ``True``, ``model.state_dict()`` will only returns full state dict on
rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
......@@ -342,7 +314,6 @@ class FullyShardedDataParallel(nn.Module):
force_input_to_fp32: bool = False,
verbose: bool = False,
cpu_offload: bool = False,
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = None,
allow_reset_parameters: bool = False,
......@@ -414,12 +385,6 @@ class FullyShardedDataParallel(nn.Module):
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
if self.ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
self.world_size
......@@ -433,9 +398,6 @@ class FullyShardedDataParallel(nn.Module):
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.ssd_offload and not self.flatten_parameters:
raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")
# skip validation if the process group was created above
if process_group:
validate_process_group(self.compute_device, self.process_group)
......@@ -456,16 +418,7 @@ class FullyShardedDataParallel(nn.Module):
self._has_params = len(params) > 0
self._has_shared_params = False
# TODO(anj): Should we conditionally do this only if we have params?
# TODO(anj): Figure out if we can allocate the buffer during sharding.
self.buffer_size = sum(p.numel() for p in params)
self.ssd_directory = tempfile.gettempdir()
if self.ssd_offload:
assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
if offload_config and offload_config.dir:
self.ssd_directory = offload_config.dir
self.move_grads_to_cpu = True
self.move_params_to_cpu = True
# For now, it is either all flatten or none flatten. This will be extended to
# multiple flatten groups in my next PR.
......@@ -478,9 +431,7 @@ class FullyShardedDataParallel(nn.Module):
param_name_groups = [param_names]
del param_names
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
)
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
del module # free original module in case it helps garbage collection
# Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
......@@ -531,8 +482,6 @@ class FullyShardedDataParallel(nn.Module):
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
# TODO(anj): This should by default be set to False for ssd_offload=True
# unless we are in the summon_full_params context.
self._return_full_state_dict = True
init_end = time.time()
......@@ -544,11 +493,6 @@ class FullyShardedDataParallel(nn.Module):
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
# Free all params at the end of initialization.
if self.ssd_offload:
for m in get_fsdp_instances(self):
m._free_ssd_offload()
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
......@@ -785,7 +729,6 @@ class FullyShardedDataParallel(nn.Module):
p._orig_size = p.data.size()
if not p._is_sharded:
if not self.ssd_offload:
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
......@@ -797,10 +740,6 @@ class FullyShardedDataParallel(nn.Module):
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)
......@@ -1014,21 +953,11 @@ class FullyShardedDataParallel(nn.Module):
backup = self._return_full_state_dict
self._return_full_state_dict = False
if self.ssd_offload:
# Move params from disk to memory before returning the local state dict.
self._move_params_to_memory()
try:
yield
finally:
self._return_full_state_dict = backup
def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU."""
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor()
def _load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
......@@ -1276,21 +1205,10 @@ class FullyShardedDataParallel(nn.Module):
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor)
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu())
else:
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
self.has_full_params = False
if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage.
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
......@@ -1366,11 +1284,6 @@ class FullyShardedDataParallel(nn.Module):
return
# A single shard of the parameters in full precision.
# TODO(another-pjohnson) - I believe this will cause memory leakage with ssd
# p.data returns a pointer to a handle, and that handle has it's
# ref count incremented by p._fp32_shard. So this tensor will
# never be freed even if we do p.to_disk(). investigate after
# PR #887 is merged
p._fp32_shard = p.data
if self.mixed_precision:
......@@ -1378,9 +1291,6 @@ class FullyShardedDataParallel(nn.Module):
if self.move_params_to_cpu:
assert p._fp32_shard.device == torch.device("cpu"), self
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
if not self.ssd_offload:
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
......@@ -1423,15 +1333,6 @@ class FullyShardedDataParallel(nn.Module):
# pass. In this case, it's important to pre-allocate the CPU grad
# shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation.
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
# Gradients also need to be offloaded to SSD otherwise it can result in
# OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.allow_unsafe_changes = True
p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file()
else:
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
def _set_is_root(self) -> None:
......@@ -1576,17 +1477,8 @@ class FullyShardedDataParallel(nn.Module):
if self.clear_autocast_cache:
torch.clear_autocast_cache()
self._free_ssd_offload()
return outputs
@torch.no_grad()
def _free_ssd_offload(self) -> None:
if self.ssd_offload:
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file(permit_when_tensor_none=True)
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward.
......@@ -1990,7 +1882,6 @@ class FullyShardedDataParallel(nn.Module):
# Update root and nested FSDP's hooks and flags.
for m in get_fsdp_instances(self):
_finalize_parameters(m)
m._free_ssd_offload()
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
......@@ -2071,15 +1962,6 @@ class FullyShardedDataParallel(nn.Module):
# Trim any padding and reshape to match original size.
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
if self.ssd_offload:
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
if not p.is_available():
self._ssd_offload_reset_param_device(p)
p.to_tensor()
self.has_full_params = False
if self._has_shared_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another FSDP instance. An example is that in eval case
......@@ -2366,24 +2248,12 @@ class FullyShardedDataParallel(nn.Module):
return consolidated_weights
@torch.no_grad()
def _ssd_offload_reset_param_device(self, param: Parameter) -> None:
assert isinstance(param, ssd_offload.SsdParameter)
if param.device != torch.device("cpu"):
param.data = param._fp32_shard
param.tensor = None
@torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
if import_ssd_offload and self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.to_tensor()
else:
p.data = p._fp32_shard
@torch.no_grad()
......@@ -2395,9 +2265,6 @@ class FullyShardedDataParallel(nn.Module):
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
if self.ssd_offload:
p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True))
else:
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
......
......@@ -8,7 +8,6 @@
from contextlib import contextmanager
from itertools import chain
import tempfile
import typing
from typing import (
TYPE_CHECKING,
......@@ -31,19 +30,6 @@ import torch
from torch import Tensor
import torch.nn as nn
try:
from fairscale.experimental.nn.ssd_offload import (
SsdFlatParameter,
SsdFlatParameterView,
SsdFlatParameterViewProperty,
_register_property,
)
import_ssd_offload = True
except ImportError:
import_ssd_offload = False
pass
from fairscale.internal.state_dict import replace_by_prefix_
if TYPE_CHECKING:
......@@ -169,15 +155,8 @@ class FlattenParamsWrapper(nn.Module):
module: nn.Module,
param_list: ParamGroups = None,
flat_param_names: Optional[List[str]] = None,
ssd_offload: bool = False,
ssd_directory: str = "",
):
super().__init__()
if ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.ssd_offload = ssd_offload
self._fpw_module = module
self.is_flattened = False
......@@ -239,13 +218,6 @@ class FlattenParamsWrapper(nn.Module):
# Init all flat_params.
for new_p_set in self._param_sets:
params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set)
if ssd_offload:
assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter.from_tensors(tensors=params)
flat_param.allow_unsafe_changes = True
flat_param.set_file_params(fname, 0)
else:
flat_param = FlatParameter(params, params[0].requires_grad)
flat_param._param_infos = param_infos
flat_param._shared_param_infos = shared_param_infos
......@@ -393,11 +365,6 @@ class FlattenParamsWrapper(nn.Module):
ps = self.get_param_views()
param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
if self.ssd_offload:
assert isinstance(p, SsdFlatParameterView)
_register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id))
else:
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
......
......@@ -6,7 +6,6 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/experimental/nn/test_auto_shard.py
tests/experimental/optim/test_dynamic_loss_scaler.py
tests/experimental/nn/test_ssd_offload.py
tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
tests/nn/data_parallel/test_fsdp_shared_weights.py
tests/nn/data_parallel/test_fsdp_pre_backward_hook.py
......@@ -50,5 +49,4 @@ tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/nn/data_parallel/test_fsdp_offload.py
tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Testing SsdFlatParameter and SsdTensorHandle modules.
"""
import filecmp
import functools
import os
import tempfile
import numpy as np
import pytest
import torch
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
def _init():
torch.manual_seed(0)
np.random.seed(0)
def test_write_read():
_init()
with tempfile.NamedTemporaryFile() as f:
ref_tensor = torch.rand(128, dtype=torch.float32)
test_tensor = torch.zeros_like(ref_tensor)
assert not torch.equal(ref_tensor, test_tensor)
so.write(ref_tensor, f.name)
so.read(test_tensor, f.name)
assert torch.equal(ref_tensor, test_tensor)
def test_ssd_handle_dispatch_fwd():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn(128)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
# This should trigger the torch_dispatch code and write
# back the results to the file
ssd_handle.add_(1)
plus1_tensor = orig_tensor.add(1)
assert torch.equal(ssd_handle.to_tensor(), plus1_tensor)
def test_ssd_handle_dispatch_bwd():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
y1 = ssd_handle + 1
y2 = orig_copy + 1
y1.sum().backward()
y2.sum().backward()
assert torch.equal(ssd_handle.grad, orig_copy.grad)
@pytest.mark.skip("broken at head")
def test_ssd_handle_dispatch_bwd_hook():
_init()
def post_backward_hook(name, grad):
print(f"BACKWARD HOOK for tensor {name} CALLED")
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
one = torch.ones(1, requires_grad=True).cuda()
orig_copy = ssd_handle.data
cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
ssd_handle.data = cuda_copy
ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle"))
one.register_hook(functools.partial(post_backward_hook, "one"))
y1 = ssd_handle + one
y1.sum().backward()
def test_ssd_handle_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
with torch.no_grad():
orig_copy = torch.empty_like(orig_tensor)
orig_copy.copy_(orig_tensor)
orig_copy.requires_grad = True
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.flush_on_dirty = False
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1)
optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1)
y1 = ssd_handle + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = orig_copy + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
assert torch.equal(ssd_handle.to_tensor(), orig_copy)
def test_torch_save_load_ssd_flat_param_on_disk():
_init()
orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
# TENSOR_SHAPE = (1024, 1024, 2048)
# use smaller shape for unit tests
TENSOR_SHAPE = (1024, 321)
ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
ssd_handle.set_file_params(orig_file.name, 0)
ssd_handle.to_file()
ref_tensors = []
# after deleting ref_tensor, memory usage should be very low
# For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
so.torch_saver.save(ssd_handle, checkpoint_file.name)
# below line saves file to checkpoint_load_directory/orig_file.name
# Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
# 1000x because that's how many elements the python unpickler
# will buffer before passing to the SsdTensor
test_ssd_handle = torch.load(checkpoint_file)
head, tail = os.path.split(orig_file.name)
assert filecmp.cmp(orig_file.name, os.path.join(checkpoint_load_directory.name, tail), shallow=False)
def test_torch_save_load_ssd_flat_param_on_mem():
_init()
orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
# TENSOR_SHAPE = (1024, 1024, 2048)
# use smaller shape for unit tests
TENSOR_SHAPE = (1024, 321)
ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
ssd_handle.set_file_params(orig_file.name, 0)
ref_tensors = []
# after deleting ref_tensor, memory usage should be very low
# For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
so.torch_saver.save(ssd_handle, checkpoint_file.name)
# below line saves file to checkpoint_load_directory/orig_file.name
# Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
# 1000x because that's how many elements the python unpickler
# will buffer before passing to the SsdTensor
test_ssd_handle = torch.load(checkpoint_file)
assert torch.equal(ssd_handle, test_ssd_handle)
def test_ssd_param_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4))
with torch.no_grad():
orig_copy = torch.empty_like(orig_tensor)
orig_copy.copy_(orig_tensor)
param = torch.nn.Parameter(orig_copy)
ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
ssd_param.point_to_tensor(orig_copy)
ssd_param.flush_on_dirty = False
ssd_param.set_file_params(f.name, 0)
ssd_param.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_param.to_tensor(), orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_param], lr=0.1)
optimizer_orig = torch.optim.SGD([param], lr=0.1)
y1 = ssd_param + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
# Test to see if Dirty is being calculated correctly when optimizer modifies
# ssd_param
assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = param + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
assert torch.equal(ssd_param.to_tensor(), param)
def test_ssd_flat_parameter_basic():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
def test_ssd_flat_parameter_view_modify():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
ssd_flat_param.flush_on_dirty = False
param_views = list(ssd_flat_param.get_param_views())
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
ssd_flat_param.to_file()
assert ssd_flat_param.storage_state == so.StorageState.ON_DISK
assert param_views[0].tensor is None
param_views[0] += 0.1
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
@pytest.mark.skip("broken at head")
def test_ssd_flat_parameter_view_bwd():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
refa_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refb_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refc_param = (
torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.data = cuda_copy
one = torch.ones(1, requires_grad=True, device=ssd_flat_param.device)
y1 = ssd_flat_param.views[0] + one
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
@pytest.mark.skip("broken at head")
def test_ssd_flat_parameter_view_bwd_parameterization():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
layer1 = torch.nn.Linear(32, 4, bias=False)
layer2 = torch.nn.Linear(32, 4, bias=False)
layer3 = torch.nn.Linear(128, 1, bias=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0
)
torch.nn.utils.parametrize.register_parametrization(
layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0)
)
torch.nn.utils.parametrize.register_parametrization(
layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1)
)
torch.nn.utils.parametrize.register_parametrization(
layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2)
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.to_file(release_tensor_after_write=False)
ssd_flat_param.data = cuda_copy
one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device)
y1 = layer1.forward(one)
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_direct_to_file():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment