"vscode:/vscode.git/clone" did not exist on "fd11b874cfdeabcf31eba0e24fa4ecebf5410e79"
Unverified Commit e71d2570 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[cleanup] remove ssd offload to simplify the FSDP code (#1080)



* simlificed the readme

* clean up ssd offload

* try to fix readthedocs
Co-authored-by: default avatarMin Xu <min.xu.public@gmail.com>
parent f4fcee7e
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# We need python > 3.8 due to a dependency on numpy.
build:
os: ubuntu-20.04
tools:
python: "3.9"
# You can also specify other tool versions:
# nodejs: "16"
# rust: "1.55"
# golang: "1.17"
# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py
# If using Sphinx, optionally build your docs in additional formats such as PDF
# formats:
# - pdf
# Optionally declare the Python requirements required to build your docs
python:
install:
- requirements: docs/requirements.txt
......@@ -25,23 +25,6 @@ FairScale was designed with the following values in mind:
[![Explain Like I’m 5: FairScale](https://img.youtube.com/vi/oDt7ebOwWIc/0.jpg)](https://www.youtube.com/watch?v=oDt7ebOwWIc)
## What's New:
* March 2022 [fairscale 0.4.6 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.6).
* We have support for CosFace's LMCL in MEVO. This is a loss function that is suitable for large number of prediction target classes.
* January 2022 [fairscale 0.4.5 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.5).
* We have experimental support for layer wise gradient scaling.
* We enabled reduce_scatter operation overlapping in FSDP backward propagation.
* December 2021 [fairscale 0.4.4 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.4).
* FairScale is tested with the following PyTorch versions (with CUDA 11.2): 1.8.1, 1.10.0 and 1.11.0.dev20211101+cu111.
* November 2021 [fairscale 0.4.3 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.3).
* We have experimental support for offloading params to disk when using the FSDP API for evaluation workloads.
* We have an experimental layer that fuses multiple layers together to support large vocab size trainings.
* November 2021 [fairscale 0.4.2 was released](https://github.com/facebookresearch/fairscale/releases/tag/v0.4.2).
* We have a new experimental API called the LayerwiseMemoryTracker to help track, visualize and suggest fixes for memory issues occurring during the forward/backward pass of your models.
* Introducing SlowMoDistributedDataParallel API, a distributed training wrapper that is useful on clusters with slow network interconnects (e.g. Ethernet).
* September 2021 [`master` branch renamed to `main`](https://github.com/github/renaming).
## Installation
To install FairScale, please see the following [instructions](https://github.com/facebookresearch/fairscale/blob/main/docs/source/installation_instructions.rst).
......@@ -50,134 +33,26 @@ You should be able to install a package with pip or conda, or build directly fro
## Getting Started
The full [documentation](https://fairscale.readthedocs.io/) contains instructions for getting started, deep dives and tutorials about the various FairScale APIs.
## Examples
Here are a few sample snippets from a subset of FairScale offerings:
### Pipe
Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.
```python
import torch
import fairscale
model = torch.nn.Sequential(a, b, c, d)
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
```
### Optimizer state sharding (ZeRO)
See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/main/benchmarks/oss.py), but a minimal example could look like the following :
```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
def train(
rank: int,
world_size: int,
epochs: int):
# DDP init example
dist.init_process_group(backend='nccl', init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
# Problem statement
model = myAwesomeModel().to(rank)
dataloader = mySuperFastDataloader()
loss_fn = myVeryRelevantLoss()
base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
# Wrap the optimizer in its state sharding brethren
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model = ShardedDDP(model, optimizer)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
for batch in dataloader:
# Train
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
# Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
mp.spawn(
train,
args=(
WORLD_SIZE,
EPOCHS,
),
nprocs=WORLD_SIZE,
join=True,
)
```
### AdaScale SGD
AdaScale can be used to wrap a SGD optimizer and to be used in DDP (Distributed Data Parallel)
training or non-DDP with gradient accumulation. The benefit is to re-use the same LR
schedule from a baseline batch size when effective batch size is bigger.
Note that AdaScale does _not_ help increase per-GPU batch size.
```python
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR # or your scheduler
from fairscale.optim import AdaScale
...
optim = AdaScale(SGD(model.parameters(), lr=0.1))
scheduler = LambdaLR(optim, ...)
...
# Note: the train loop should be with DDP or with gradient accumulation.
last_epoch = 0
step = 0
done = False
while not done:
for sample in dataset:
...
step += optim.gain()
optim.step()
epoch = step // len(dataset)
if last_epoch != epoch:
scheduler.step()
last_epoch = epoch
if epoch > max_epoch:
done = True
```
## FSDP
Primary goal is to allow scaling to bigger batch sizes without losing model accuracy.
(However, training time might be longer comparing to without AdaScale.)
At a high level, we want ML researchers to:
* go parallel more easily (i.e. no need to find new learning rate schedules)
* not worrying about losing accuracy
* potentially higher GPU efficiency (fewer steps, less networking overhead, etc.)
FullyShardedDataParallel (FSDP) is the recommended method for scaling to large NN models.
This library has been [upstreamed to PyTorch](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
The version of FSDP here is for historical references as well as for experimenting with
new and crazy ideas in research of scaling techniques. Please see the following blog
for [how to use FairScale FSDP and how does it work](https://engineering.fb.com/2021/07/15/open-source/fsdp/).
## Testing
We use circleci to test FairScale with the following PyTorch versions (with CUDA 11.2):
* the latest stable release (1.10.0)
* the latest LTS release (1.8.1)
* a recent nightly release (1.11.0.dev20211101+cu111)
* the latest stable release (e.g. 1.10.0)
* the latest LTS release (e.g. 1.8.1)
* a recent nightly release (e.g. 1.11.0.dev20211101+cu111)
Please create an [issue](https://github.com/facebookresearch/fairscale/issues) if you are having trouble with installation.
## Contributors
We welcome outside contributions! Please see the [CONTRIBUTING](CONTRIBUTING.md) instructions for how you can contribute to FairScale.
We welcome contributions! Please see the [CONTRIBUTING](CONTRIBUTING.md) instructions for how you can contribute to FairScale.
## License
......@@ -198,22 +73,9 @@ If you use FairScale in your publication, please cite it by using the following
```BibTeX
@Misc{FairScale2021,
author = {Mandeep Baines and Shruti Bhosale and Vittorio Caggiano and Naman Goyal and Siddharth Goyal and Myle Ott and Benjamin Lefaudeux and Vitaliy Liptchinsky and Mike Rabbat and Sam Sheiffer and Anjali Sridhar and Min Xu},
author = {FairScale authors},
title = {FairScale: A general purpose modular PyTorch library for high performance and large scale training},
howpublished = {\url{https://github.com/facebookresearch/fairscale}},
year = {2021}
}
```
## FAQ
1. If you experience an error indicating a default branch does not exist, it probably due to the latest update, switching the default branch from "master" to "main"
```
error: pathspec 'non-existing-branch' did not match any file(s) known to git
```
Please run the following commands to update to the main branch.
```
git branch -m master main
git fetch origin
git branch -u origin/main main
git remote set-head origin -a
```
......@@ -25,7 +25,6 @@ from torch.optim import Adam
from benchmarks.golden_configs.lm_wikitext2 import FSDP as lm_wikitext2
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import OffloadConfig
RPC_PORT = 29501
......@@ -95,10 +94,7 @@ def get_lm_model(args, device, config):
nhid = config["nhid"]
ndecoder = config["num_decoder_layers"]
if args.ssd_offload:
return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder)
else:
return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
return transformer_lm.TransformerLM(vocab_size, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
def get_tensors_by_size_bucket():
......@@ -200,7 +196,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
if i > 0:
total_tokens += source.numel()
if args.benchmark_eval or args.ssd_offload:
if args.benchmark_eval:
input = source.cuda()
target = target.cuda()
output = model(input)
......@@ -250,7 +246,6 @@ def get_number_of_words(data):
def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
# TODO(anj): Uncomment and add a check for regression once we have a couple of runs.
golden_config = get_golden_config(args.model_name, args)
epoch = benchmark_config["epochs"]
start_time = time.time()
......@@ -358,8 +353,6 @@ def benchmark_fsdp(rank, args, world_size):
model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
model = model_config["model"]
config = {}
if args.ssd_offload:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
if args.full_fp16:
config["compute_dtype"] = torch.float16
......@@ -386,7 +379,6 @@ parser.add_argument("--max_batch", type=int, default=4, help="Max number of batc
parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
parser.add_argument(
# TODO(anj-s): In the process of adding more models and hence the requirement for a flag.
"--model_name",
default="lm",
help="Language Model(LM) used to benchmark FSDP.",
......@@ -394,7 +386,6 @@ parser.add_argument(
parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")
parser.add_argument("--enable_auto_wrap", action="store_true", default=False, help="Use auto_wrap with FSDP")
parser.add_argument("--benchmark_eval", action="store_true", default=False, help="Benchmark evaluation workflow.")
parser.add_argument("--ssd_offload", action="store_true", default=False, help="Benchmark ssd_offload workflow.")
parser.add_argument("--full_fp16", action="store_true", default=False, help="Benchmark in full fp16 mode.")
if __name__ == "__main__":
......
......@@ -30,7 +30,7 @@ sys.path.insert(0, os.path.abspath("../.."))
# -- Project information -----------------------------------------------------
project = "FairScale"
copyright = "2020-2021, Facebook/Meta AI Research"
copyright = "2020-2022, Facebook/Meta AI Research"
author = "Facebook/Meta AI Research"
# -- General configuration ---------------------------------------------------
......@@ -68,7 +68,7 @@ autodoc_inherit_docstrings = False
autodoc_member_order = "bysource"
intersphinx_mapping = {
"python": ("https://docs.python.org/3.6", None),
"python": ("https://docs.python.org/3.8", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
}
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from enum import Enum, auto
from functools import reduce
import io
import os
import pickle
from types import TracebackType
from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union
import numpy as np
import torch
from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL
from fairscale.internal import torch_version
try:
from torch.utils._pytree import tree_map
except ImportError:
# The PyTorch version(<1.9) we test with does not support the tree_map API.
pass
if torch_version() < (1, 12, 0):
raise ImportError(
f"ssd_offload only works on torch versions 1.12.0 and beyond, but torch version is: {torch.__version__}"
)
DEFAULT_CHUNK_SIZE = 2048 * 2048
def _get_num_chunks(input_tensor: torch.Tensor, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE) -> int:
"""Returns the number of chunks that the given tensor can be divided into."""
size_in_bytes = input_tensor.nelement() * input_tensor.element_size()
num_chunks = (size_in_bytes + (chunk_size_bytes - 1)) // chunk_size_bytes
return num_chunks
def _tensor_to_bytes_chunks(
input_tensor: torch.Tensor, chunk_idx: int, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE
) -> bytes:
"""Converts the given tensor into a chunked array containing chunk_size_bytes."""
size_in_bytes = input_tensor.nelement() * input_tensor.element_size()
assert chunk_idx < _get_num_chunks(input_tensor, chunk_size_bytes)
input_tensor_np = input_tensor.detach().numpy().view(np.uint8).reshape(-1)
chunk_start = chunk_idx * chunk_size_bytes
chunk_end = min(size_in_bytes, chunk_start + chunk_size_bytes)
return input_tensor_np[chunk_start:chunk_end].tobytes()
def write(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0) -> None:
"""Populates the file with the data stored in the given tensor."""
num_chunks = _get_num_chunks(input_tensor)
file_flags = "r+b" if os.path.exists(filename) else "wb"
with open(filename, file_flags) as f:
f.seek(file_offset_bytes)
for i in range(num_chunks):
f.write(_tensor_to_bytes_chunks(input_tensor, i))
def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0) -> None:
"""Populates the given tensor with the data stored in a file."""
size_in_bytes = input_tensor.nelement() * input_tensor.element_size()
chunk_size_bytes = DEFAULT_CHUNK_SIZE
num_chunks = _get_num_chunks(input_tensor)
input_tensor_np = input_tensor.detach().numpy()
input_tensor_mv = memoryview(input_tensor_np.view(dtype=np.uint8).reshape(-1))
with io.open(filename, "rb") as f:
f.seek(file_offset_bytes)
for i in range(num_chunks):
chunk_start = i * chunk_size_bytes
chunk_end = min(size_in_bytes, chunk_start + chunk_size_bytes)
data_read = f.readinto(input_tensor_mv[chunk_start:chunk_end])
if data_read != chunk_end - chunk_start:
raise RuntimeError(
f"Attempted to read {chunk_end - chunk_start} more bytes from {filename}, but only read: {data_read} bytes. Total Bytes read = {chunk_start + data_read}, total bytes expected: {size_in_bytes}"
)
class StorageState(Enum):
"""
Simple enum to indicate whether the tensor handle is pointing
to data on disk or memory. This is useful for asserting on
whether the tensor is available for operations or if it needs
to be moved from disk to CPU or device.
"""
UNALLOCATED = auto()
ON_DISK = auto()
ON_CPU_CLEAN = auto()
ON_CPU_DIRTY = auto()
class SsdTensorHandle(torch.Tensor):
"""
This class extends from torch.Tensor and represents a Tensor which is backed by SSD storage.
The SsdTensorHandle object can point to a file or a tensor and there are corresponding functions to read
data into the tensor that is an attribute of the SsdTensorHandle object or write the tensor to file. At any
point in time the Tensor may be in memory or on disk.
Class Variables:
override_directory_path: This variable is used by CheckpointPathContextManager to modify the path to any
SsdTensorHandles that are saved to a checkpoint via pickling (e.g. torch.save)
Args:
shape torch.Size: Shape of the tensor that is represented by the handle.
dtype: torch.dtype: Dtype of the tensor that is represented by the handle.
requires_grad: bool: Property of the tensor that is represeneted by the handle.
Returns:
A SSDTensorHandle object representing a Tensor.
"""
override_directory_path: Optional[str] = None
@staticmethod
def __new__(
cls: Type[SsdTensorHandle],
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = False,
device: torch.device = torch.device("cpu"),
flush_on_dirty: bool = True,
allow_unsafe_changes: bool = False,
) -> SsdTensorHandle:
r = super(SsdTensorHandle, cls)._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad, device=device) # type: ignore
return r
def __init__(
self,
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool,
device: torch.device = torch.device("cpu"),
flush_on_dirty: bool = True,
allow_unsafe_changes: bool = False,
) -> None:
self._unpickle_f: Optional[Union[BinaryIO, IO[bytes]]] = None
self._shape = shape
if len(shape) == 0:
self._numel = 0
else:
self._numel = reduce((lambda x, y: x * y), shape)
self._dtype = dtype
# valid if offloaded to file
self.filename = ""
self.offset = -1
# valid if loaded to memory
self.tensor: Optional[torch.Tensor] = None
self.storage_state = StorageState.UNALLOCATED
self.flush_on_dirty = flush_on_dirty
self.allow_unsafe_changes = allow_unsafe_changes
def mark_dirty(self) -> None:
assert self.tensor is not None
assert self.storage_state in [StorageState.ON_CPU_CLEAN, StorageState.ON_CPU_DIRTY]
self.storage_state = StorageState.ON_CPU_DIRTY
# hack to force write on mark_dirty
if self.flush_on_dirty:
self.to_file()
@classmethod
def from_file(
cls, shape: torch.Size, dtype: torch.dtype, filename: str, offset: int = 0, requires_grad: bool = False
) -> SsdTensorHandle:
"""Returns a new SsdTensorHandle from a file."""
handle = cls(shape=shape, dtype=dtype, requires_grad=requires_grad)
handle.point_to_file(filename, offset=offset)
return handle
@classmethod
def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle:
"""Returns a new SsdTensorHandle from a tensor."""
handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device)
handle.point_to_tensor(tensor)
return handle
def is_available(self) -> bool:
return self.tensor is not None
def get_tensor(self) -> torch.Tensor:
assert self.tensor is not None
return self.tensor
def set_file_params(self, filename: str, offset: int) -> None:
self.filename = filename
self.offset = offset
def point_to_file(self, filename: str, offset: int) -> None:
self.set_file_params(filename, offset)
self.tensor = None
self.storage_state = StorageState.ON_DISK
def point_to_tensor(self, tensor: torch.Tensor) -> None:
assert self.tensor is None
if not self.allow_unsafe_changes:
assert self._shape == tensor.shape
assert self._dtype == tensor.dtype
self.tensor = tensor
self.storage_state = StorageState.ON_CPU_DIRTY
# if resizing a handle that is part of an ssd buffer, care must be taken that the new size
# doesn't conflict with adjacent handles!
def point_to_resized_tensor(self, tensor: torch.Tensor) -> None:
assert self._dtype == tensor.dtype
self._shape = tensor.shape
self.tensor = tensor
def to_tensor(self) -> torch.Tensor:
"""Returns the tensor represented by the SsdTensorHandle object.
If the tensor is on disk, it is copied into the tensor attribute and returned.
"""
if self.tensor is not None:
return self.tensor
else:
if self.device != torch.device("cpu"):
raise RuntimeError(
f"to_tensor called on an SsdTensorHandle when the tensor has been offloaded to disk. self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!"
)
result_tensor = torch.empty(size=self.shape, dtype=self.dtype, requires_grad=self.requires_grad)
self.copy_into_tensor(result_tensor)
self.tensor = result_tensor
self.storage_state = StorageState.ON_CPU_CLEAN
return self.tensor
def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
"""Saves the tensor to disk and releases memory if specified."""
assert self.tensor is not None or permit_when_tensor_none
# if it's available in Memory but not modified, no need to write-back
if self.tensor is not None:
if self.storage_state is StorageState.ON_CPU_DIRTY:
if self.device != torch.device("cpu"):
raise RuntimeError(
f"to_file called on an SsdTensorHandle when self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!"
)
write(self.tensor, self.filename, self.offset * self.tensor.element_size())
if release_tensor_after_write:
self.tensor = None
self.storage_state = StorageState.ON_DISK
else:
self.storage_state = StorageState.ON_CPU_CLEAN
def copy_into_tensor(self, tensor: torch.Tensor) -> None:
"""Copies SsdTensorHandle's data into the given tensor.
If the tensor is in memory, this function copies the data
into the passed in tensor. Otherwise, it reads from file into tensor,
using the read() function.
This does not modify modify self.tensor unlike the to_tensor()
function. This can be useful for calls like named_parameters() when
the tensor is already offloaded to disk.
"""
# ideally this should be checked but .data shenanigans forces it to
# be disabled due to the way FSDP shards parameters
# assert self._shape == tensor.shape
assert self._dtype == tensor.dtype
if self.tensor is not None:
tensor.copy_(self.tensor)
else:
read(tensor, self.filename, self.offset * tensor.element_size())
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore
"""Intercepts all operations performed on this handle object.
Before any operation, the tensor attribute is unwrapped from the handle
and used in the operation. We maintain a refernce to the tensor and its current
versions to track if modifications have been made. If we detect changes to the
tensor, we write it to the file maintained by the Handle.
"""
func_name = func.overloadpacket.__name__
ssd_tensor_handles = []
def unwrap(e: Any) -> torch.Tensor:
if isinstance(e, SsdTensorHandle):
t = e.to_tensor()
ssd_tensor_handles.append(e)
return t
else:
return e
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for e in ssd_tensor_handles:
inplace_is_this_tensor = (
(func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i")
) and e is args[0]
out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"]
if inplace_is_this_tensor or out_is_this_tensor:
e.mark_dirty()
return r
def __setattr__(self, name: str, value: Any) -> None:
if name == "data":
assert isinstance(value, torch.Tensor)
if not self.allow_unsafe_changes:
# Respect .data changes, and the user better know what they are doing!
if self.storage_state == StorageState.ON_CPU_DIRTY:
raise RuntimeError(
"Attempting to override tensor when the existing tensor is dirty, this is an error!"
)
if value.shape != self.shape:
raise RuntimeError(
f"Attempting to override tensor metadata using .data to change shape of tensor. Orig shape: {self.shape} New shape: {value.shape}"
)
if value.requires_grad != self.requires_grad:
raise RuntimeError(
f"Attempting to override tensor metadata using .data to change requires_grad. Orig value: {self.requires_grad} New value: {value.requires_grad}"
)
self.tensor = value
super(SsdTensorHandle, self).__setattr__(name, value)
@classmethod
def __unpickle__(
cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool, filename: str
) -> SsdTensorHandle:
result = cls(shape, dtype, requires_grad)
result.point_to_file(filename, 0)
result._unpickle_f = io.open(result.filename, "wb")
return result
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any, Any]:
byte_iter = None
filename = self.filename
if self.override_directory_path is not None:
head, tail = os.path.split(self.filename)
filename = os.path.join(self.override_directory_path, tail)
if self.is_available():
byte_iter = iter(TensorChunkingIterator(self.tensor)) # ignore: type
else:
byte_iter = iter(
FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size())
)
return (
self.__unpickle__, # Callable
# Args to the callable above
(self._shape, self._dtype, self.requires_grad, filename),
None,
byte_iter,
)
def append(self, item: bytes) -> None:
assert self._unpickle_f
self._unpickle_f.write(item)
def extend(self, items: List[bytes]) -> None:
for i in items:
self.append(i)
class CheckpointPathContextManager:
"""
This Context allows the user to override the directory path when pickling an SsdTensorHandle Object.
It is needed because the filename which the SsdTensorHandle points to (and is used when unpickling)
is already baked into the pickled data.
Consider the following example code
ssd_handle = SsdTensorHandle.from_tensor(ref_tensor)
ssd_handle.set_file_params('/home/user/handle.bin', 0)
torch.save(ssd_handle, '/home/user/checkpoint.pkl')
ssd_handle += 1
ssd_handle.to_file()
ssd_handle2 = torch.load('/home/user/checkpoint.pkl')
print(f"handles are equal: {torch.equals(ssd_handle, ssd_handle2)}")
One would expect this to print False, however unintuitively it will print True.
ssd_handle.filename and ssd_handle2.filename are equal. This means that
when we execute torch.load, we read from the .pkl file and write the result into
/home/user/handle.bin, clobbering the updated result from `ssd_handle += 1`
We want to give the user the possibility of not clobbering the data using this
Context Manager.
ssd_handle = SsdTensorHandle.from_tensor(ref_tensor)
ssd_handle.set_file_params('/home/user/handle.bin', 0)
with CheckpointPathContextManager(override_path='/home/user/checkpoint_data/'):
torch.save(ssd_handle, '/home/user/checkpoint.pkl')
ssd_handle += 1
ssd_handle.to_file()
ssd_handle2 = torch.load('/home/user/checkpoint.pkl')
print(f"handles are equal: {torch.equals(ssd_handle, ssd_handle2)}")
This code results with ssd_handle.filename = '/home/user/handle.bin' and ssd_handle2.filename =
`/home/user/checkpoint_data/handle.bin'. Therefore the torch.load won't clobber ssd_handle, and
the printed result is False.
"""
def __init__(self, override_path: str) -> None:
self.old_path = SsdTensorHandle.override_directory_path
self.override_path = override_path
def __enter__(self) -> None:
SsdTensorHandle.override_directory_path = self.override_path
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exec_traceback: Optional[TracebackType],
) -> None:
SsdTensorHandle.override_directory_path = self.old_path
# Classes supporting torch.save/load
class TorchSaver:
def __init__(self) -> None:
self.pickle_module = DisableMemoizationPicklerModule
def save(
self, obj: Any, f: Union[str, os.PathLike, BinaryIO, IO[bytes]], pickle_protocol: int = DEFAULT_PROTOCOL
) -> None:
torch.serialization.save(
obj, f, self.pickle_module, pickle_protocol=pickle_protocol, _use_new_zipfile_serialization=False
)
class SsdParameter(SsdTensorHandle, torch.nn.Parameter):
@classmethod
def from_tensor(cls: Type[SsdParameter], tensor: SsdTensorHandle) -> SsdParameter: # type: ignore
r = cls(tensor.shape, tensor.dtype, tensor.requires_grad, device=tensor.device)
r.point_to_tensor(tensor)
return r
@staticmethod
def __new__(
cls: Type[SsdParameter],
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> SsdParameter:
r = super(SsdParameter, cls).__new__(cls, shape=shape, dtype=dtype, requires_grad=requires_grad, device=device)
return r # type: ignore
def __init__(
self,
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> None:
super(SsdParameter, self).__init__(shape=shape, dtype=dtype, requires_grad=requires_grad, device=device)
class SsdFlatParameter(SsdParameter):
"""A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
This class should eventually be moved to fairscale/nn/misc/flatten_params_wrapper.py
"""
def __new__(
cls: Type[SsdFlatParameter],
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> SsdFlatParameter:
"""Make an object using the parent's __new__ function."""
# A empty of non-list input doesn't make sense.
if not isinstance(shapes, (list, tuple)) or len(shapes) == 0:
raise ValueError("An non-empty list or tuple argument is needed")
size = sum([np.prod(s) for s in shapes])
r = super(SsdFlatParameter, cls).__new__(
cls, torch.Size((size,)), dtype=dtype, requires_grad=requires_grad, device=device
)
return r # type: ignore
def __init__(
self,
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
):
"""Initialize the _param_numels and _param_shapes lists."""
self._param_shapes = shapes
self._param_numels = [np.prod(s) for s in shapes]
total_numels = sum(self._param_numels)
assert (
self.numel() <= total_numels
), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
self.views: List[SsdFlatParameterView] = []
# These are set by FPW class below, not by this class itself.
self._param_infos: List[Tuple[str, torch.nn.Module, str]] = []
self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = []
super(SsdFlatParameter, self).__init__(
shape=torch.Size((total_numels,)), dtype=dtype, requires_grad=requires_grad
)
def __setattr__(self, name: str, value: Any) -> None:
super(SsdFlatParameter, self).__setattr__(name, value)
if name == "data":
# if .data has changed, we need to totally destroy any existing views because things
# like device might have changed. It won't destroy any pointers to those views outside
# of here, however resetting self.views will trigger the old view's assertion in
# __torch_dispatch__ that it is the current view of it's parent object
self.views = []
self._refresh_views()
def _invalidate_views(self) -> None:
for v in self.views:
v.tensor = None
@torch.enable_grad()
def _refresh_views(self) -> None:
if self._shape != self.shape:
self.views = []
return
if len(self.views) == 0:
self.views = [s.view(v) for s, v in zip(self.split(self._param_numels), self._param_shapes)] # type: ignore
else:
for v, t, s in zip(self.views, self.tensor.split(self._param_numels), self._param_shapes):
v.tensor = t.view(s)
def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]:
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
"""
assert self.data.numel() <= sum(
self._param_numels
), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
"""
if external_data is not None:
if external_data.numel() != sum(self._param_numels):
raise ValueError(
f"Incorrect numel of supplied data: got {external_data.numel()} but expected {sum(self._param_numels)}"
)
return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes))
else:
# this needs to return SsdFlatParameterViews
if not self.is_available():
self.to_tensor()
if len(self.views) == 0:
raise RuntimeError(
"Trying to call get_param_views when self.views is empty, this means that .data games have been played and the current .data shape doesn't match the constructed shape."
)
return (v for v in self.views)
def metadata(self) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter."""
names = [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]
return names, self._param_shapes, self._param_numels
@classmethod
def from_tensors(
cls: Type[SsdFlatParameter],
tensors: Sequence[torch.Tensor],
direct_to_file: bool = False,
filename: str = "",
offset: int = 0,
) -> "SsdFlatParameter":
"""Returns a new SsdFlatParameter from a sequence of tensors."""
assert (
len(tensors) > 0
), "SsdFlatParameter.from_tensors must be called with at least one tensor in the tensors argument"
# Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
# heirarchy flat (using a single tensor to replace a tree of tensors). Therefore,
# adding back nesting and heirarchy is counter-productive. If nesting is encountered
# in the future, the reasonable thing to do is likely for the top level SsdFlatParameter to
# absorb the nested one and keep the result flat, free from hierarchy.
if any(isinstance(t, SsdFlatParameter) for t in tensors):
raise ValueError("Nesting SsdFlatParameter is not supported")
requires_grad = tensors[0].requires_grad
dtype = tensors[0].dtype
device = tensors[0].device
for t in tensors:
if t.requires_grad != requires_grad:
raise RuntimeError("Not all tensors have identical requires_grad option")
if t.dtype != dtype:
raise RuntimeError("Not all tensors have identical dtype option")
if t.device != device:
raise RuntimeError("Not all tensors have identical device option")
handle = cls(
shapes=[t.size() for t in tensors],
dtype=tensors[0].dtype,
requires_grad=tensors[0].requires_grad,
device=device,
)
handle.set_file_params(filename, offset)
if direct_to_file:
assert filename != ""
offset = offset
for t in tensors:
write(t, handle.filename, offset)
offset += t.numel() * t.element_size()
handle.storage_state = StorageState.ON_DISK
else:
tensor = torch.cat(
[t.reshape(-1) if isinstance(t, torch.nn.Parameter) else t.reshape(-1) for t in tensors],
0,
).detach()
tensor.requires_grad_()
handle.point_to_tensor(tensor)
return handle
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore
func_name = func.overloadpacket.__name__
r = super(SsdFlatParameter, cls).__torch_dispatch__(func, types, args, kwargs) # type: ignore
if func_name.startswith("split"):
assert isinstance(args[0], SsdFlatParameter)
parent = args[0]
return [SsdFlatParameterView(parent, t, idx) for idx, t in enumerate(r)]
else:
return r
# need to subclass these methods to support Views
def point_to_tensor(self, tensor: torch.Tensor) -> None:
super(SsdFlatParameter, self).point_to_tensor(tensor)
self._refresh_views()
def point_to_file(self, filename: str, offset: int) -> None:
super(SsdFlatParameter, self).point_to_file(filename, offset)
self._invalidate_views()
def to_tensor(self) -> torch.Tensor:
call_refresh_views = False
if self.tensor is None:
call_refresh_views = True
result = super(SsdFlatParameter, self).to_tensor()
if call_refresh_views:
self._refresh_views()
return result
def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
super(SsdFlatParameter, self).to_file(permit_when_tensor_none, release_tensor_after_write)
self._invalidate_views()
@classmethod
def __unpickle_SFP__(
cls: Type[SsdFlatParameter],
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool,
filename: str,
) -> SsdFlatParameter:
result = cls(shapes, dtype, requires_grad)
result.point_to_file(filename, 0)
result._unpickle_f = io.open(result.filename, "wb")
return result
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any, Any]:
byte_iter = None
filename = self.filename
if self.override_directory_path is not None:
head, tail = os.path.split(self.filename)
filename = os.path.join(self.override_directory_path, tail)
if self.is_available():
byte_iter = iter(TensorChunkingIterator(self.tensor))
else:
byte_iter = iter(
FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size())
)
return (
self.__unpickle_SFP__, # Callable
# Args to the callable above
(self._param_shapes, self._dtype, self.requires_grad, filename),
None,
byte_iter,
)
class SsdFlatParameterView(torch.Tensor):
"""
Represents a view into an SsdFlatParameter. It is needed due to FSDP's usage of flattening parameters.
"""
def __new__(
cls: Type[SsdFlatParameterView], parent: SsdFlatParameter, tensor: torch.Tensor, id: int
) -> SsdFlatParameterView:
r = super(SsdFlatParameterView, cls)._make_wrapper_subclass(cls, tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device) # type: ignore
return r
def __init__(self: SsdFlatParameterView, parent: SsdFlatParameter, tensor: torch.Tensor, id: int) -> None:
self.parent = parent
self.tensor: Optional[torch.Tensor] = tensor
self.id = id
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore
"""Intercepts all operations performed on this handle object.
Before any operation, the tensor attribute is unwrapped from the handle
and used in the operation. We maintain a refernce to the tensor and its current
versions to track if modifications have been made. If we detect changes to the
tensor, we write it to the file maintained by the Handle.
"""
func_name = func.overloadpacket.__name__
ssd_tensor_handles = []
def unwrap(e: Any) -> torch.Tensor:
if isinstance(e, SsdFlatParameterView):
if not e.parent.is_available():
e.parent.to_tensor()
# first condition is to take care of the case where we are first constructing e.parent.views as a list comprehension which hasn't
# completed yet
if len(e.parent.views) != 0 and e is not e.parent.views[e.id]:
raise RuntimeError(
"This view should no longer be used as the parent object has had it's .data overwritten (e.parent.views[e.id])!!!"
)
# e.parent will ensure that e.tensor is valid and points to tensor view
t = e.tensor
ssd_tensor_handles.append(e)
return t
else:
return e
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for e in ssd_tensor_handles:
inplace_is_this_tensor = (
(func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i")
) and e is args[0]
out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"]
if inplace_is_this_tensor or out_is_this_tensor:
e.parent.mark_dirty()
if func_name.startswith("view"):
assert isinstance(args[0], SsdFlatParameterView)
flat_view = args[0]
return SsdFlatParameterView(flat_view.parent, r, flat_view.id)
return r
# ###################################
# ### BEGIN OVERRIDE_PROPERTY FNs ###
# ###################################
# This code is taken mostly from pytorch core parameterization
# pytorch/torch/nn/utils/parametrize.py
def _inject_new_class(module: torch.nn.Module) -> None:
r"""Sets up a module to be parametrized.
This works by substituting the class of the module by a class
that extends it to be able to inject a property
Args:
module (nn.Module): module into which to inject the property
"""
cls = module.__class__
def getstate(self): # type: ignore
raise RuntimeError(
"Serialization of parametrized modules is only "
"supported through state_dict(). See:\n"
"https://pytorch.org/tutorials/beginner/saving_loading_models.html"
"#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
)
param_cls = type(
f"Parametrized{cls.__name__}",
(cls,),
{
"__getstate__": getstate,
},
)
module.__class__ = param_cls
module.override_properties: Dict[str, Callable[[], torch.Tensor]] = {} # type: ignore
# setattr(module, "override_properties", {})
def _inject_property(module: torch.nn.Module, property_name: str) -> None:
r"""Injects a property into module[property_name].
It assumes that the class in the module has already been modified from its
original one using _inject_new_class and that the tensor under :attr:`property_name`
has already been moved out
Args:
module (nn.Module): module into which to inject the property
property_name (str): name of the name of the property to create
"""
def get_parametrized(self: torch.nn.Module) -> torch.Tensor:
prop: Callable[[], torch.Tensor] = self.override_properties[property_name] # type: ignore
# If caching is not active, this function just evaluates the parameterization
return prop()
def set_original(self: torch.nn.Module, value: Callable[[], torch.Tensor]) -> None:
self.override_properties[property_name] = value # type: ignore
def del_fn(self: torch.nn.Module) -> None:
_remove_property(self, property_name)
setattr(module.__class__, property_name, property(get_parametrized, set_original, del_fn))
def _register_property(module: torch.nn.Module, property_name: str, property_value: Callable[[], torch.Tensor]) -> None:
has_injected_class = hasattr(module, "override_properties")
if not has_injected_class:
_inject_new_class(module)
if hasattr(module, property_name):
delattr(module, property_name)
module.override_properties[property_name] = property_value # type: ignore
_inject_property(module, property_name)
def _remove_property(module: torch.nn.Module, property_name: str, new_property_value: Optional[Any] = None) -> None:
delattr(module.__class__, property_name)
del module.override_properties[property_name] # type: ignore
# Roll back the parametrized class if no other buffer or parameter
# is currently parametrized in this class
if len(module.override_properties) == 0: # type: ignore
delattr(module, "override_properties")
# Restore class
orig_cls = module.__class__.__bases__[0]
module.__class__ = orig_cls
if new_property_value is not None:
setattr(module.__class__, property_name, new_property_value)
# #################################
# ### END OVERRIDE_PROPERTY FNs ###
# #################################
class SsdFlatParameterViewProperty:
"""
Allows for a mutable view to replace a layer's trainable parameters.
This is needed since FSDP is changing .data under the covers,
SsdFlatParameter cannot just rely on this since each view (of type SsdFlatParameterView) has
an internal representation. So every time we access a view, we need to
make sure we get the up-to-date version, and not the original version
when flattening the parameters.
"""
def __init__(self, parent: SsdFlatParameter, view_id: int) -> None:
super().__init__()
self.parent = parent
self.view_id = view_id
def __call__(self) -> SsdFlatParameterView:
return self.parent.views[self.view_id]
class SsdFlatParameterViewParameterization(torch.nn.Module):
def __init__(self, parent: SsdFlatParameter, view_id: int) -> None:
super().__init__()
self.parent = parent
self.view_id = view_id
def forward(self, *args: Any, **kwargs: Any) -> SsdFlatParameterView:
return self.parent.views[self.view_id]
class DisableMemoizationPicklerModule:
@classmethod
def Pickler(cls, data_buf: io.BytesIO, protocol: int) -> pickle.Pickler:
p = pickle.Pickler(data_buf, protocol)
p.fast = True
return p
@classmethod
def dump(cls, obj: Any, f: io.BytesIO, protocol: int) -> None:
pickle.dump(obj, f, protocol)
class TensorChunkingIterator:
"""
chunk_size_bytes determines how large each chunk that we break the tensor
into. It is important to consider limiting the size because by when
python unpickles an object, by default it will read up to 1000 list
elements at a time. So memory usage while unpickling will be on the
order of O(min(file_size, 1000 * chunk_size_bytes)).
"""
def __init__(self, tensor: torch.Tensor, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE) -> None:
self.tensor = tensor
self.chunk_size_bytes = chunk_size_bytes
def __iter__(self) -> Iterator[bytes]:
self.num_chunks = _get_num_chunks(self.tensor, self.chunk_size_bytes)
self.num_chunks_read = 0
return self
def __next__(self) -> bytes:
if self.num_chunks_read >= self.num_chunks:
raise StopIteration
next_chunk = _tensor_to_bytes_chunks(
self.tensor, chunk_idx=self.num_chunks_read, chunk_size_bytes=self.chunk_size_bytes
)
self.num_chunks_read += 1
return next_chunk
class FileChunkingIterator:
"""
chunk_size_bytes determines how large each chunk that we break the file
into. It is important to consider limiting the size because by when
python unpickles an object, by default it will read up to 1000 list
elements at a time. So memory usage while unpickling will be on the
order of O(min(file_size, 1000 * chunk_size_bytes)).
"""
def __init__(
self, filename: str, expected_size_bytes: int = -1, chunk_size_bytes: int = DEFAULT_CHUNK_SIZE
) -> None:
self.filename = filename
self.file: Optional[Union[BinaryIO, IO[bytes]]] = None
self.chunk_size_bytes = chunk_size_bytes
self.expected_size_bytes = expected_size_bytes
def __iter__(self) -> Iterator[bytes]:
if self.expected_size_bytes != -1:
file_size = os.stat(self.filename).st_size
assert (
file_size == self.expected_size_bytes
), f"FileChunkingIterator Failed, expecting file to be of size: {self.expected_size_bytes} but got {file_size}"
self.file = io.open(self.filename, "rb", buffering=0)
self.num_chunks_read = 0
return self
def __next__(self) -> bytes:
assert self.file
next_chunk = self.file.read(self.chunk_size_bytes)
if len(next_chunk) == 0:
raise StopIteration
self.num_chunks_read += 1
return next_chunk
torch_saver = TorchSaver()
......@@ -9,7 +9,6 @@ import torch.distributed as dist
from .fully_sharded_data_parallel import (
FullyShardedDataParallel,
OffloadConfig,
TrainingState,
auto_wrap_bn,
get_fsdp_instances,
......
......@@ -5,13 +5,11 @@
import contextlib
import copy
from dataclasses import dataclass
from enum import Enum, auto
import functools
import logging
from math import inf
import os
import tempfile
import time
import traceback
import typing
......@@ -69,15 +67,6 @@ if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
else:
enable_nccl_base_collectives = True
try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
import_ssd_offload = True
except ImportError:
# The latest nightly PyTorch version required
import_ssd_offload = False
pass
class TrainingState(Enum):
"""
......@@ -107,19 +96,6 @@ class TrainingState(Enum):
SUMMON_FULL_PARAMS = auto()
# Data classes containing FSDP parameter constructs
# Offload config for specifying SSD options (initially at least)
@dataclass
class OffloadConfig:
"""Class for specifying all arguments related to offloading parameters."""
# Offload type: currently only supports: "ssd_offload"
offload_type: Optional[str] = None
# Path to the directory for storing parameters offloaded to disk.
dir: Optional[str] = None
class FullyShardedDataParallel(nn.Module):
"""
A wrapper for sharding Module parameters across data parallel workers. This
......@@ -302,10 +278,6 @@ class FullyShardedDataParallel(nn.Module):
cpu_offload (bool, Optional):
if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of
*``move_params_to_cpu``* in an upcoming release.
offload_config (OffloadConfig):
The `OffloadConfig` object is used to specify the type of offload (i.e SSD, CPU) and
other required knobs when offloading parameters from GPU. Currently the OffloadConfig
only supports specifying SSD offload as an option. Note: This is an experimental feature.
state_dict_on_rank_0_only (bool):
When set to ``True``, ``model.state_dict()`` will only returns full state dict on
rank 0 and return empty dict non-rank 0, which allow FullyShardedDataParallel to
......@@ -342,7 +314,6 @@ class FullyShardedDataParallel(nn.Module):
force_input_to_fp32: bool = False,
verbose: bool = False,
cpu_offload: bool = False,
offload_config: Optional[OffloadConfig] = None,
state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = None,
allow_reset_parameters: bool = False,
......@@ -414,12 +385,6 @@ class FullyShardedDataParallel(nn.Module):
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
if self.ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
self.world_size
......@@ -433,9 +398,6 @@ class FullyShardedDataParallel(nn.Module):
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.ssd_offload and not self.flatten_parameters:
raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")
# skip validation if the process group was created above
if process_group:
validate_process_group(self.compute_device, self.process_group)
......@@ -456,16 +418,7 @@ class FullyShardedDataParallel(nn.Module):
self._has_params = len(params) > 0
self._has_shared_params = False
# TODO(anj): Should we conditionally do this only if we have params?
# TODO(anj): Figure out if we can allocate the buffer during sharding.
self.buffer_size = sum(p.numel() for p in params)
self.ssd_directory = tempfile.gettempdir()
if self.ssd_offload:
assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
if offload_config and offload_config.dir:
self.ssd_directory = offload_config.dir
self.move_grads_to_cpu = True
self.move_params_to_cpu = True
# For now, it is either all flatten or none flatten. This will be extended to
# multiple flatten groups in my next PR.
......@@ -478,9 +431,7 @@ class FullyShardedDataParallel(nn.Module):
param_name_groups = [param_names]
del param_names
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
)
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
del module # free original module in case it helps garbage collection
# Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
......@@ -531,8 +482,6 @@ class FullyShardedDataParallel(nn.Module):
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
# TODO(anj): This should by default be set to False for ssd_offload=True
# unless we are in the summon_full_params context.
self._return_full_state_dict = True
init_end = time.time()
......@@ -544,11 +493,6 @@ class FullyShardedDataParallel(nn.Module):
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
# Free all params at the end of initialization.
if self.ssd_offload:
for m in get_fsdp_instances(self):
m._free_ssd_offload()
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
......@@ -785,10 +729,9 @@ class FullyShardedDataParallel(nn.Module):
p._orig_size = p.data.size()
if not p._is_sharded:
if not self.ssd_offload:
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True
# TODO (Min): broadcast from rank 0 to avoid each rank need to init with the same seed?
......@@ -797,11 +740,7 @@ class FullyShardedDataParallel(nn.Module):
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
free_storage_(orig_data)
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)
......@@ -1014,21 +953,11 @@ class FullyShardedDataParallel(nn.Module):
backup = self._return_full_state_dict
self._return_full_state_dict = False
if self.ssd_offload:
# Move params from disk to memory before returning the local state dict.
self._move_params_to_memory()
try:
yield
finally:
self._return_full_state_dict = backup
def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU."""
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor()
def _load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
......@@ -1276,22 +1205,11 @@ class FullyShardedDataParallel(nn.Module):
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor)
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu())
else:
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
self.has_full_params = False
if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage.
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
self._use_fp32_param_shard()
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
def _reset_lazy_init(self) -> None:
......@@ -1366,11 +1284,6 @@ class FullyShardedDataParallel(nn.Module):
return
# A single shard of the parameters in full precision.
# TODO(another-pjohnson) - I believe this will cause memory leakage with ssd
# p.data returns a pointer to a handle, and that handle has it's
# ref count incremented by p._fp32_shard. So this tensor will
# never be freed even if we do p.to_disk(). investigate after
# PR #887 is merged
p._fp32_shard = p.data
if self.mixed_precision:
......@@ -1378,14 +1291,11 @@ class FullyShardedDataParallel(nn.Module):
if self.move_params_to_cpu:
assert p._fp32_shard.device == torch.device("cpu"), self
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
if not self.ssd_offload:
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard
if self.move_params_to_cpu or self.mixed_precision:
......@@ -1423,16 +1333,7 @@ class FullyShardedDataParallel(nn.Module):
# pass. In this case, it's important to pre-allocate the CPU grad
# shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation.
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
# Gradients also need to be offloaded to SSD otherwise it can result in
# OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.allow_unsafe_changes = True
p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file()
else:
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
def _set_is_root(self) -> None:
"""If ``True``, implies that no other :class:`FullyShardedDataParallel`
......@@ -1576,17 +1477,8 @@ class FullyShardedDataParallel(nn.Module):
if self.clear_autocast_cache:
torch.clear_autocast_cache()
self._free_ssd_offload()
return outputs
@torch.no_grad()
def _free_ssd_offload(self) -> None:
if self.ssd_offload:
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file(permit_when_tensor_none=True)
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward.
......@@ -1990,7 +1882,6 @@ class FullyShardedDataParallel(nn.Module):
# Update root and nested FSDP's hooks and flags.
for m in get_fsdp_instances(self):
_finalize_parameters(m)
m._free_ssd_offload()
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
......@@ -2071,15 +1962,6 @@ class FullyShardedDataParallel(nn.Module):
# Trim any padding and reshape to match original size.
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
if self.ssd_offload:
for p in self.params:
assert isinstance(p, ssd_offload.SsdParameter)
if not p.is_available():
self._ssd_offload_reset_param_device(p)
p.to_tensor()
self.has_full_params = False
if self._has_shared_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another FSDP instance. An example is that in eval case
......@@ -2366,25 +2248,13 @@ class FullyShardedDataParallel(nn.Module):
return consolidated_weights
@torch.no_grad()
def _ssd_offload_reset_param_device(self, param: Parameter) -> None:
assert isinstance(param, ssd_offload.SsdParameter)
if param.device != torch.device("cpu"):
param.data = param._fp32_shard
param.tensor = None
@torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
if import_ssd_offload and self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.to_tensor()
else:
p.data = p._fp32_shard
p.data = p._fp32_shard
@torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
......@@ -2395,14 +2265,11 @@ class FullyShardedDataParallel(nn.Module):
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
if self.ssd_offload:
p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True))
else:
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
......
......@@ -8,7 +8,6 @@
from contextlib import contextmanager
from itertools import chain
import tempfile
import typing
from typing import (
TYPE_CHECKING,
......@@ -31,19 +30,6 @@ import torch
from torch import Tensor
import torch.nn as nn
try:
from fairscale.experimental.nn.ssd_offload import (
SsdFlatParameter,
SsdFlatParameterView,
SsdFlatParameterViewProperty,
_register_property,
)
import_ssd_offload = True
except ImportError:
import_ssd_offload = False
pass
from fairscale.internal.state_dict import replace_by_prefix_
if TYPE_CHECKING:
......@@ -169,15 +155,8 @@ class FlattenParamsWrapper(nn.Module):
module: nn.Module,
param_list: ParamGroups = None,
flat_param_names: Optional[List[str]] = None,
ssd_offload: bool = False,
ssd_directory: str = "",
):
super().__init__()
if ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.ssd_offload = ssd_offload
self._fpw_module = module
self.is_flattened = False
......@@ -239,14 +218,7 @@ class FlattenParamsWrapper(nn.Module):
# Init all flat_params.
for new_p_set in self._param_sets:
params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set)
if ssd_offload:
assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter.from_tensors(tensors=params)
flat_param.allow_unsafe_changes = True
flat_param.set_file_params(fname, 0)
else:
flat_param = FlatParameter(params, params[0].requires_grad)
flat_param = FlatParameter(params, params[0].requires_grad)
flat_param._param_infos = param_infos
flat_param._shared_param_infos = shared_param_infos
self.flat_params.append(flat_param)
......@@ -393,13 +365,8 @@ class FlattenParamsWrapper(nn.Module):
ps = self.get_param_views()
param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
if self.ssd_offload:
assert isinstance(p, SsdFlatParameterView)
_register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id))
else:
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
# Save param views for easy access if anyone still wants to access
# parameters of the module.
......
......@@ -6,7 +6,6 @@ tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/experimental/nn/test_auto_shard.py
tests/experimental/optim/test_dynamic_loss_scaler.py
tests/experimental/nn/test_ssd_offload.py
tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
tests/nn/data_parallel/test_fsdp_shared_weights.py
tests/nn/data_parallel/test_fsdp_pre_backward_hook.py
......@@ -50,5 +49,4 @@ tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/nn/data_parallel/test_fsdp_offload.py
tests/nn/data_parallel/test_fsdp_fwd_fwd_bwd_bwd.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Testing SsdFlatParameter and SsdTensorHandle modules.
"""
import filecmp
import functools
import os
import tempfile
import numpy as np
import pytest
import torch
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
def _init():
torch.manual_seed(0)
np.random.seed(0)
def test_write_read():
_init()
with tempfile.NamedTemporaryFile() as f:
ref_tensor = torch.rand(128, dtype=torch.float32)
test_tensor = torch.zeros_like(ref_tensor)
assert not torch.equal(ref_tensor, test_tensor)
so.write(ref_tensor, f.name)
so.read(test_tensor, f.name)
assert torch.equal(ref_tensor, test_tensor)
def test_ssd_handle_dispatch_fwd():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn(128)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
# This should trigger the torch_dispatch code and write
# back the results to the file
ssd_handle.add_(1)
plus1_tensor = orig_tensor.add(1)
assert torch.equal(ssd_handle.to_tensor(), plus1_tensor)
def test_ssd_handle_dispatch_bwd():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
y1 = ssd_handle + 1
y2 = orig_copy + 1
y1.sum().backward()
y2.sum().backward()
assert torch.equal(ssd_handle.grad, orig_copy.grad)
@pytest.mark.skip("broken at head")
def test_ssd_handle_dispatch_bwd_hook():
_init()
def post_backward_hook(name, grad):
print(f"BACKWARD HOOK for tensor {name} CALLED")
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
one = torch.ones(1, requires_grad=True).cuda()
orig_copy = ssd_handle.data
cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
ssd_handle.data = cuda_copy
ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle"))
one.register_hook(functools.partial(post_backward_hook, "one"))
y1 = ssd_handle + one
y1.sum().backward()
def test_ssd_handle_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
with torch.no_grad():
orig_copy = torch.empty_like(orig_tensor)
orig_copy.copy_(orig_tensor)
orig_copy.requires_grad = True
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.flush_on_dirty = False
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1)
optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1)
y1 = ssd_handle + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = orig_copy + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
assert torch.equal(ssd_handle.to_tensor(), orig_copy)
def test_torch_save_load_ssd_flat_param_on_disk():
_init()
orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
# TENSOR_SHAPE = (1024, 1024, 2048)
# use smaller shape for unit tests
TENSOR_SHAPE = (1024, 321)
ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
ssd_handle.set_file_params(orig_file.name, 0)
ssd_handle.to_file()
ref_tensors = []
# after deleting ref_tensor, memory usage should be very low
# For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
so.torch_saver.save(ssd_handle, checkpoint_file.name)
# below line saves file to checkpoint_load_directory/orig_file.name
# Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
# 1000x because that's how many elements the python unpickler
# will buffer before passing to the SsdTensor
test_ssd_handle = torch.load(checkpoint_file)
head, tail = os.path.split(orig_file.name)
assert filecmp.cmp(orig_file.name, os.path.join(checkpoint_load_directory.name, tail), shallow=False)
def test_torch_save_load_ssd_flat_param_on_mem():
_init()
orig_file = tempfile.NamedTemporaryFile(prefix="tensor")
checkpoint_file = tempfile.NamedTemporaryFile(prefix="checkpoint", suffix=".pt")
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
# TENSOR_SHAPE = (1024, 1024, 2048)
# use smaller shape for unit tests
TENSOR_SHAPE = (1024, 321)
ref_tensors = [torch.rand(TENSOR_SHAPE, dtype=torch.float32) for i in range(4)]
ssd_handle = so.SsdFlatParameter.from_tensors(ref_tensors, False)
ssd_handle.set_file_params(orig_file.name, 0)
ref_tensors = []
# after deleting ref_tensor, memory usage should be very low
# For save it shouldn't be more than 10x so.DEFAULT_CHUNK_SIZE
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
so.torch_saver.save(ssd_handle, checkpoint_file.name)
# below line saves file to checkpoint_load_directory/orig_file.name
# Memory usage here should be O(1000 * so.DEFAULT_CHUNK_SIZE)
# 1000x because that's how many elements the python unpickler
# will buffer before passing to the SsdTensor
test_ssd_handle = torch.load(checkpoint_file)
assert torch.equal(ssd_handle, test_ssd_handle)
def test_ssd_param_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4))
with torch.no_grad():
orig_copy = torch.empty_like(orig_tensor)
orig_copy.copy_(orig_tensor)
param = torch.nn.Parameter(orig_copy)
ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
ssd_param.point_to_tensor(orig_copy)
ssd_param.flush_on_dirty = False
ssd_param.set_file_params(f.name, 0)
ssd_param.to_file(release_tensor_after_write=True)
assert torch.equal(ssd_param.to_tensor(), orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_param], lr=0.1)
optimizer_orig = torch.optim.SGD([param], lr=0.1)
y1 = ssd_param + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
# Test to see if Dirty is being calculated correctly when optimizer modifies
# ssd_param
assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = param + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
assert torch.equal(ssd_param.to_tensor(), param)
def test_ssd_flat_parameter_basic():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
def test_ssd_flat_parameter_view_modify():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
ssd_flat_param.flush_on_dirty = False
param_views = list(ssd_flat_param.get_param_views())
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
ssd_flat_param.to_file()
assert ssd_flat_param.storage_state == so.StorageState.ON_DISK
assert param_views[0].tensor is None
param_views[0] += 0.1
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
@pytest.mark.skip("broken at head")
def test_ssd_flat_parameter_view_bwd():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
refa_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refb_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refc_param = (
torch.nn.Parameter(torch.rand(128, dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.data = cuda_copy
one = torch.ones(1, requires_grad=True, device=ssd_flat_param.device)
y1 = ssd_flat_param.views[0] + one
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
@pytest.mark.skip("broken at head")
def test_ssd_flat_parameter_view_bwd_parameterization():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
layer1 = torch.nn.Linear(32, 4, bias=False)
layer2 = torch.nn.Linear(32, 4, bias=False)
layer3 = torch.nn.Linear(128, 1, bias=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0
)
torch.nn.utils.parametrize.register_parametrization(
layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0)
)
torch.nn.utils.parametrize.register_parametrization(
layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1)
)
torch.nn.utils.parametrize.register_parametrization(
layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2)
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.to_file(release_tensor_after_write=False)
ssd_flat_param.data = cuda_copy
one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device)
y1 = layer1.forward(one)
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_direct_to_file():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand(128, dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
import sys
import tempfile
import time
import unittest
from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed
pytestmark = pytest.mark.skip(reason="ssd offload to be removed to simplify the code")
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
from fairscale.fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod
class DistributedTest(unittest.TestCase):
def setUp(self):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA not available, skipping test")
if sys.platform == "win32":
raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
if torch.cuda.device_count() < 2:
raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
@staticmethod
def _eval_with_config(model, autocast):
model.eval()
model_device = torch.device("cuda")
with torch.cuda.amp.autocast(enabled=autocast):
# Inputs always cuda regardless of move_grads_cpu, or model.device
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
return loss.detach()
@staticmethod
def _eval_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
model.eval()
# Inputs always cuda regardless of move_grads_cpu, or model.device
input = model.module.get_input(torch.device("cuda"))
for _ in range(num_steps):
with torch.cuda.amp.autocast(enabled=autocast):
output = model(*input)
@classmethod
def _test_identical_outputs_eval(
cls,
model_init_fn,
config,
rank,
group,
num_steps=2,
use_cuda=True,
lr=0.01,
ref_ddp_fn=None,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False
# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._eval_with_config(model, autocast)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()
# Confirm we get the same behavior using FullyShardedDataParallel.
if config.get("ssd_offload", False):
config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
# ssd offload only supports flatten_params ATM
config["flatten_parameters"] = True
del config["ssd_offload"]
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if not model.ssd_offload and not model.move_params_to_cpu:
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._eval_with_config(model, autocast)
try:
torch.testing.assert_allclose(ref_loss, shard_loss)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
if config.get("flatten_parameters", True):
metadata = model.local_metadata_dict()
assert isinstance(metadata, dict)
keys = ["reshard_after_forward", "mixed_precision", "nested_wrapping"]
CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
def rename_test(testcase_func, param_num, param):
return "%s_%s" % (
testcase_func.__name__,
parameterized.to_safe_name(str(param.args)),
)
class TestSsdMemory(DistributedTest):
def test_memory_benchmark(self):
test_fn = functools.partial(self._test_memory_benchmark, config={})
spawn_and_init(test_fn)
@classmethod
def _test_memory_benchmark(self, rank, group, config):
time_keeper = TimeKeeper()
SIZE = 8 * 8
time_keeper.print_time("START", 1.0)
a = torch.empty(1)
b = a.cuda()
# wait for cuda to fully load
time.sleep(1)
time_keeper.print_time("INIT_CUDA", 1.0)
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
time_keeper.print_time("CPU_MODEL", 1.0)
with tempfile.TemporaryDirectory() as current_tempdir:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
model = FullyShardedDataParallel(model, **config)
time_keeper.print_time("FSDP_MODEL", 1.0)
self._eval_for_several_steps(model, 1, autocast=False)
time_keeper.print_time("EVAL")
class SimpleLinear(nn.Module):
def __init__(self, group, input_size, output_size, layers=1, **unused_kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
self.input_size = input_size
self.output_size = output_size
torch.manual_seed(0) # keep everything deterministic
seq_layers = []
for i in range(layers):
seq_layers.append(nn.Linear(input_size, output_size, bias=False))
self.module = nn.Sequential(*seq_layers)
self.bs = 2
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
src = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32)
tgt = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32)
return (src, tgt)
def forward(self, src_ids, tgt_ids):
param_devices = [p.device for p in self.module.parameters()]
return self.module(src_ids)
def get_loss(self, input, output):
_, tgt = input
return nn.functional.binary_cross_entropy_with_logits(output, tgt)
def run_backward(self, loss):
loss.backward()
KEYS = ["ssd_offload", "flatten_parameters", "mixed_precision", "move_params_to_cpu"]
CONFIG = [[dict(zip(KEYS, config))] for config in itertools.product([True, False], repeat=len(KEYS))]
class TimeKeeper:
def __init__(self):
self.start_time = time.time()
def print_time(self, s: str, wait_time: float = 1.0):
cur_time = time.time()
print(f"@time: {cur_time - self.start_time:0.2f} {s}")
time.sleep(wait_time)
class TestModuleProperties(DistributedTest):
@parameterized.expand(CONFIG, name_func=rename_test)
def test_named_parameters(self, config):
test_fn = functools.partial(self._test_named_params, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_named_params(self, rank, group, config):
# Get the named parameters before wrapping.
before_wrap_model = TransformerWithSharedParams(group)
before_wrap_params = before_wrap_model.named_parameters()
with tempfile.TemporaryDirectory() as current_tempdir:
if config["ssd_offload"]:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
# ssd offload only supports flatten_params ATM
config["flatten_parameters"] = True
del config["ssd_offload"]
model = FullyShardedDataParallel(before_wrap_model, **config)
print(f"model.ssd_offload {model.ssd_offload}")
if not model.ssd_offload and not model.move_params_to_cpu:
model = model.cuda()
self._eval_with_config(model, autocast=config["mixed_precision"])
# Get the named parameters after wrapping to compare.
after_wrap_params = model.named_parameters()
if not config.get("flatten_parameters", False):
for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
assert before_nm[0] == after_nm[0]
else:
named_params_flat = [p for p in after_wrap_params][0][0]
assert "flat_param_0" in named_params_flat
after_wrap_params = model.named_parameters()
for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
assert before_nm[0] == after_nm_original[0]
torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape)
class TestSsdLoading(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_eval(self, config):
test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG, name_func=rename_test)
def test_transformer_parameterized(self, config):
spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_train_flatten_params_wrapper(self, config):
test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config):
SIZE = 16 * 16
LR = 0.01
MOMENTUM = 0.1
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
with tempfile.TemporaryDirectory() as current_tempdir:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
config["flatten_parameters"] = True
nested_wrapping = config["nested_wrapping"]
del config["nested_wrapping"]
if nested_wrapping:
model = FullyShardedDataParallel(
NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
)
else:
model = FullyShardedDataParallel(model, **config)
model_device = torch.device("cuda")
model.train()
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint_file = tempfile.NamedTemporaryFile()
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
pre_checkpoint_last_output = None
post_checkpoint_last_output = None
ITERATIONS = 10
# Inputs always cuda regardless of move_grads_cpu, or model.device
with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
for i in range(ITERATIONS):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
pre_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} pre_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
if i == 0:
with so.CheckpointPathContextManager(override_path=checkpoint_load_directory.name):
# so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name)
torch.save({"model": model.state_dict()}, checkpoint_file.name)
# reset momentum just after checkpoint save
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint = torch.load(checkpoint_file.name)
model.load_state_dict(checkpoint["model"])
# reset momentum just after checkpoint load
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
# do more iterations after loading checkpoint
for i in range(ITERATIONS - 1):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
post_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} post_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
# Verify output of checkpoint load + run is equal to original output
assert torch.equal(pre_checkpoint_last_output, post_checkpoint_last_output)
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
@classmethod
def _test_ssd_offload_eval(self, rank, group, config):
model = TransformerWithSharedParams(group)
state_dict = model.state_dict()
nested_wrapping = config["nested_wrapping"]
del config["nested_wrapping"]
config["flatten_parameters"] = True
with tempfile.TemporaryDirectory() as current_tempdir:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
if nested_wrapping:
model = FullyShardedDataParallel(
NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
)
else:
model = FullyShardedDataParallel(model, **config)
self._eval_with_config(model, autocast=config["mixed_precision"])
# With SSD offload only local_state_dict will work. We can support global
# state dict if we think it is necessary.
state_dict = model.local_state_dict()
model.load_local_state_dict(state_dict)
self._eval_with_config(model, config["mixed_precision"])
class TransformerWithSharedParams(nn.Module):
def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
torch.manual_seed(0) # keep everything deterministic
assert d_vocab >= 12 # we use torch.arange(12) as input
self.embed_tokens = nn.Embedding(d_vocab, d_model)
self.transformer = nn.Transformer(
d_model=d_model,
num_encoder_layers=2,
num_decoder_layers=2,
dim_feedforward=8,
dropout=0.1,
)
self.output_proj = nn.Linear(d_model, d_vocab)
# share the embedding and output projection weights
self.output_proj.weight = self.embed_tokens.weight
self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))
self.bs = 2
self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
src = torch.arange(12, device=device).view(6, self.bs) # T x B
tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs) # T x B
return (src, tgt)
def forward(self, src_ids, tgt_ids):
src = self.embed_tokens(src_ids)
src = src + self.vocab_bias + self.long_buffer.type_as(src)
tgt = self.embed_tokens(tgt_ids)
tgt = self.bn(tgt)
x = self.transformer(src, tgt)
return self.output_proj(x)
def get_loss(self, input, output):
_, tgt = input
return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum")
def run_backward(self, loss):
loss.backward()
class NestedWrappedModule(nn.Module):
def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
self.wrapper_config = wrapper_config
def _maybe_wrap(layer):
if wrapper_config is not None:
return FullyShardedDataParallel(layer, group, **wrapper_config)
return layer
torch.manual_seed(0) # keep everything deterministic
self.module = nn.Sequential(
nn.Linear(8, 4),
_maybe_wrap(
nn.Sequential(
_maybe_wrap(nn.Linear(4, 16)),
nn.Linear(16, 16),
)
),
_maybe_wrap(nn.Linear(16, 4)),
nn.Linear(4, 8),
)
# Wrap all modules triggers a corner case where root FSDP doesn't have any params.
# Test it with checkpoint_wrapper as well to validate final backward callback
# is queued correctly when root FSDP does not have any params and every layer is
# wrapped as FSDP(checkpoint(module)).
if wrap_everything:
if checkpoint:
self.module = nn.Sequential(
_maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
)
else:
self.module = nn.Sequential(
_maybe_wrap(nn.Linear(8, 4)),
_maybe_wrap(nn.Linear(4, 16)),
_maybe_wrap(nn.Linear(16, 4)),
_maybe_wrap(nn.Linear(4, 8)),
)
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
return (torch.rand(4, 8, device=device),)
def forward(self, x):
return self.module(x)
def get_loss(self, input, output):
loss = output.sum()
return loss
def run_backward(self, loss):
loss.backward()
def spawn_and_init(fn, args=None, **spawn_kwargs):
if args is None:
args = ()
run_fn = functools.partial(init_and_run, fn, args)
# Below 3 lines are to easily enable single-process debugging
# _, filename = tempfile.mkstemp()
# _, filename_rpc = tempfile.mkstemp()
# run_fn(0, 1, filename, filename_rpc)
spawn_for_all_world_sizes(run_fn, **spawn_kwargs)
def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
dist_init(rank, world_size, filename, filename_rpc)
group = torch.distributed.new_group()
fn(rank, group, *args)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment