Unverified Commit d7c4aa52 authored by anj-s, committed by GitHub

[feature] Add support for SSD offload with FSDP for eval workloads (#839)

* update release notes

* initial commit

* lint cleanup etc.

* helper functions; lint errors

* lint errors

* lint errors

* add back the boolean for named_parameters

* address comments and fix lint

* remove unused functions and class

* remove unused state
parent 21464e05
......@@ -9,10 +9,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Fixed a pre-backward hook bug for certain types of models and FSDP configs. [#833]
### Added
- LayerwiseMemoryTracker[feature][experimental] - This is a new experimental tool to help track, visualize and suggest fixes for memory issues occurring during the forward/backward pass of your models. [#808]
- [FSDP]: limited support of shared weights between FSDP wrappers. This allows large parameter
- FSDP: Add support for SSD offload for eval workloads. This is a new experimental feature and should be
used with caution.
- LayerwiseMemoryTracker[feature][experimental]: This is a new experimental tool to help track, visualize and suggest fixes for memory issues occurring during the forward/backward pass of your models. [#808]
- FSDP: limited support of shared weights between FSDP wrappers. This allows large parameter
and gradient memory to be sharded despite being needed from different layers due to
weight sharing. [#836]
- OffloadModel: Fix node names to enable correct sharding in auto_shard.py [#830]
- OSS: Relaxed speed and memory constraints on OSS golden data due to a regression when we bumped the
PyTorch version to 1.9. [#828] [#825]
- Chore: Update PyTorch version that we run benchmarks with. [#823]
- Chore: Update PyTorch version that we run tests with. [#809]
- OffloadModel: Extend auto_shard.py to allow dealing with conditionals automatically when tracing with
torch.fx. This will work for most cases except when the conditional is part of the root instance. [#817]
- [MEVO]: a custom layer to help with big vocab trainings. Experimental. Docs are still TBD. [#840]
- SlowMoDistributedDataParallel[feature][experimental] - This is a distributed training wrapper which should be useful on clusters with slow network interconnects (e.g. Ethernet). This improves performance compared to Distributed Data Parallel in such clusters. [#378]
......
......@@ -5,6 +5,7 @@
from __future__ import annotations
from enum import Enum, auto
from functools import reduce
import io
import os
......@@ -68,6 +69,19 @@ def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0)
assert data_read == chunk_end - chunk_start
class StorageState(Enum):
"""
Simple enum to indicate whether the tensor handle is pointing
to data on disk or memory. This is useful for asserting on
whether the tensor is available for operations or if it needs
to be moved from disk to CPU or device.
"""
UNALLOCATED = auto()
ON_DISK = auto()
ON_CPU = auto()
class SsdTensorHandle(torch.Tensor):
"""
This class extends torch.Tensor and represents a tensor that is backed by SSD storage.
......@@ -104,6 +118,7 @@ class SsdTensorHandle(torch.Tensor):
# valid if loaded to memory
self.tensor: Optional[torch.Tensor] = None
self.requires_grad = requires_grad
self.storage_state = StorageState.UNALLOCATED
@classmethod
def from_file(
......@@ -112,6 +127,7 @@ class SsdTensorHandle(torch.Tensor):
"""Returns a new SsdTensorHandle from a file."""
handle = cls(shape=shape, dtype=dtype, requires_grad=requires_grad)
handle.filename = filename
handle.storage_state = StorageState.ON_DISK
return handle
@classmethod
......@@ -119,6 +135,7 @@ class SsdTensorHandle(torch.Tensor):
"""Returns a new SsdTensorHandle from a tensor."""
handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad)
handle.tensor = tensor
handle.storage_state = StorageState.ON_CPU
return handle
def is_available(self) -> bool:
......@@ -153,6 +170,7 @@ class SsdTensorHandle(torch.Tensor):
result_tensor = torch.empty(size=self._shape, dtype=self._dtype, requires_grad=self.requires_grad)
self.copy_into_tensor(result_tensor)
self.tensor = result_tensor
self.storage_state = StorageState.ON_CPU
return self.tensor
def to_file(self, release_tensor_after_write: bool = True) -> None:
......@@ -161,6 +179,7 @@ class SsdTensorHandle(torch.Tensor):
write(self.tensor, self.filename, self.offset * self.tensor.element_size())
if release_tensor_after_write:
self.tensor = None
self.storage_state = StorageState.ON_DISK
def copy_into_tensor(self, tensor: torch.Tensor) -> None:
"""Copies SsdTensorHandle's data into the given tensor.
......@@ -225,11 +244,12 @@ class SsdBuffer:
self.filename = filename
self.offset = 0
self.tensors: Dict[int, SsdTensorHandle] = {}
self.storage_state = StorageState.ON_CPU
def allocate(self, num_elems: int) -> SsdTensorHandle:
"""Allocates a new tensor handle of size num_elems."""
assert num_elems > 0
assert list(self.buffer.size()) != [1]
assert self.storage_state == StorageState.ON_CPU, self.storage_state
assert self.can_alloc(num_elems)
tensor = self.buffer.narrow(0, self.offset, num_elems)
......@@ -244,7 +264,7 @@ class SsdBuffer:
def insert(self, tensor: torch.Tensor) -> SsdTensorHandle:
"""Insert a new tensor by allocating memory and creating a corresponding handle."""
assert list(self.buffer.size()) != [1]
assert self.storage_state == StorageState.ON_CPU, self.storage_state
# For the non-sharded case, the tensor will not be flattened
tensor = tensor.reshape(-1)
assert self.buffer.dtype == tensor.dtype
......@@ -255,7 +275,7 @@ class SsdBuffer:
def can_alloc(self, num_elems: int) -> bool:
"""Verify that you can allocate a tensor within the bounds
of the larger SsdBuffer memory buffer."""
assert list(self.buffer.size()) != [1]
assert self.storage_state == StorageState.ON_CPU, self.storage_state
return (self.offset + num_elems) <= self.buffer.numel()
def get_tensors(self) -> List[SsdTensorHandle]:
......@@ -264,8 +284,11 @@ class SsdBuffer:
def to_disk(self) -> None:
"""Writes all tensors backed by handles to disk."""
assert list(self.buffer.size()) != [1]
# TODO(anj): Add comment about why we use `narrow`.
if self.storage_state == StorageState.ON_DISK:
return
assert self.storage_state == StorageState.ON_CPU, self.storage_state
# We use `narrow` so that we write valid tensors that have been allocated
# as opposed to the entire SSD buffer.
valid_data = self.buffer.narrow(0, 0, self.offset)
write(valid_data, self.filename)
......@@ -276,9 +299,13 @@ class SsdBuffer:
# TODO(anj-s): Setting this to None does not result in GC picking
# this reference up.
self.buffer = torch.empty((1))
self.storage_state = StorageState.ON_DISK
def from_disk(self, num_elems: int, dtype: torch.dtype = torch.float32) -> None:
"""Reads all tensors backed by handles into memory."""
if self.storage_state == StorageState.ON_CPU:
return
assert self.storage_state == StorageState.ON_DISK, self.storage_state
if num_elems < self.offset:
raise RuntimeError(
f"Attempted to load from file ssdbuffer of size: {self.offset} into a buffer that is of size: {num_elems}"
......@@ -289,3 +316,5 @@ class SsdBuffer:
for offset, t in self.tensors.items():
t.point_to_tensor(self.buffer.narrow(0, t.offset, t._numel))
self.storage_state = StorageState.ON_CPU
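To make the ssd_offload.py API above easier to follow, here is a minimal usage sketch (not part of this commit). The SsdBuffer(num_elems, filename) constructor signature and the handle behavior are inferred from this diff and from how FSDP calls the API further down; the file path and sizes are placeholders.

import torch
import fairscale.experimental.nn.ssd_offload as ssd_offload

buf = ssd_offload.SsdBuffer(16, "/tmp/ssd_buffer_rank0")  # CPU-backed buffer, StorageState.ON_CPU
handle = buf.insert(torch.rand(4, 4))                     # tensor is flattened and copied into the buffer

buf.to_disk()      # writes only the allocated region to the file and frees the CPU buffer (ON_DISK)

buf.from_disk(16)  # reads the data back and re-points every handle at the new buffer (ON_CPU)
flat = buf.get_tensors()[0].get_tensor()  # flat CPU tensor; .view() restores the original shape,
                                          # the same way FSDP uses p._shard_size below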
......@@ -10,6 +10,7 @@ import functools
import logging
from math import inf
import os
from random import randint
import time
import traceback
import typing
......@@ -61,6 +62,15 @@ if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
else:
enable_nccl_base_collectives = True
try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
import_ssd_offload = True
except ImportError:
# The latest nightly PyTorch version is required for SSD offload.
import_ssd_offload = False
pass
class TrainingState(Enum):
"""
......@@ -272,6 +282,7 @@ class FullyShardedDataParallel(nn.Module):
force_input_to_fp32: bool = False,
verbose: bool = False,
cpu_offload: bool = False,
**kwargs: Dict[str, Any],
):
init_start = time.time()
super().__init__()
......@@ -294,6 +305,8 @@ class FullyShardedDataParallel(nn.Module):
self.clear_autocast_cache = clear_autocast_cache
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
# Experimental feature for now. Use at your own risk.
self.ssd_offload = kwargs.get("ssd_offload", False)
self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
......@@ -323,6 +336,18 @@ class FullyShardedDataParallel(nn.Module):
self._has_params = len(params) > 0
# TODO(anj): Should we conditionally do this only if we have params?
# TODO(anj): Figure out if we can allocate the buffer during sharding.
self.buffer_size = sum(p.numel() for p in params)
self.ssd_buffer_filename = ""
if self.ssd_offload:
assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
# TODO(anj): Add support for temp file and directory as possible API params.
self.ssd_buffer_filename = f"{randint(1, int(10E6))}_rank{self.rank}"
self.ssd_buffer = ssd_offload.SsdBuffer(self.buffer_size, self.ssd_buffer_filename)
self.move_grads_to_cpu = True
self.move_params_to_cpu = True
# For now, it is either all flatten or none flatten. This will be extended to
# multiple flatten groups in my next PR.
to_be_flatten_params: List[List[Parameter]] = [[]]
......@@ -375,6 +400,8 @@ class FullyShardedDataParallel(nn.Module):
# Flag to indicate whether state_dict() should automatically summon the
# full params. This defaults to True, but may be set to False if the
# user explicitly requests the local state dict via local_state_dict().
# TODO(anj): This should by default be set to False for ssd_offload=True
# unless we are in the summon_full_params context.
self._return_full_state_dict = True
init_end = time.time()
......@@ -386,6 +413,12 @@ class FullyShardedDataParallel(nn.Module):
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
# Free all params at the end of initialization.
if self.ssd_offload:
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
m._free_ssd_offload()
def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
......@@ -614,20 +647,42 @@ class FullyShardedDataParallel(nn.Module):
p._orig_size = p.data.size()
if not p._is_sharded:
if self.ssd_offload:
# Insert tensor into the SSD buffer and free parameter storage.
p._is_sharded = False
self.numel_padded_per_param.append(0)
p._shard_size = p.data.size() # type: ignore
p._handle = self.ssd_buffer.insert(p.data) # type: ignore
free_storage_(p.data)
continue
else:
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True
# Replace p.data with the relevant shard.
if self.ssd_offload:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
p._shard_size = p.data.size() # type: ignore
# Insert tensor into the SSD buffer and free parameter storage.
p._handle = self.ssd_buffer.insert(p.data) # type: ignore
del orig_data
self.numel_padded_per_param.append(num_padded)
free_storage_(p.data)
else:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data)
p._is_sharded = True
assert len(self.numel_padded_per_param) == len(self.params)
# Move SSD buffer to disk.
if self.ssd_offload:
self.ssd_buffer.to_disk()
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
......@@ -704,6 +759,18 @@ class FullyShardedDataParallel(nn.Module):
del self.orig_sizes
self._reset_lazy_init()
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
"""Returns an iterator over the module parameters, yielding all the parameters
part of the model.
"""
# TODO(anj): Use `copy_into_tensor` in order to provide a copy of the
# parameters and not the actual parameters. Ideally we don't want users to operate on
# the actual params.
if self.ssd_offload:
self.ssd_buffer.from_disk(self.buffer_size)
return super().parameters(recurse=recurse)
def named_parameters(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Parameter]]:
"""Returns an iterator over the module parameters, yielding both the name of the
parameter as well as the parameter.
......@@ -715,6 +782,12 @@ class FullyShardedDataParallel(nn.Module):
If you want the full param to be returned, you should call this function
under a `summon_full_params` context when using flattened or original params.
"""
# TODO(anj): Use `copy_into_tensor` in order to provide a copy of the
# parameters and not the actual parameters. Ideally we don't want users to operate on
# the actual params.
if self.ssd_offload:
self.ssd_buffer.from_disk(self.buffer_size)
named_param = super().named_parameters(*args, **kwargs)
for name, param in named_param:
if (
......@@ -810,11 +883,23 @@ class FullyShardedDataParallel(nn.Module):
def _no_return_full_state_dict(self) -> Generator:
backup = self._return_full_state_dict
self._return_full_state_dict = False
if self.ssd_offload:
# Move params from disk to memory before returning the local state dict.
self._move_params_to_memory()
try:
yield
finally:
self._return_full_state_dict = backup
def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU."""
self.ssd_buffer.from_disk(self.buffer_size)
for p, handle in zip(self.params, self.ssd_buffer.get_tensors()):
p.data = handle.get_tensor().view(p._shard_size) # type: ignore
def _load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
......@@ -962,6 +1047,13 @@ class FullyShardedDataParallel(nn.Module):
free_storage_(full_tensor)
self.has_full_params = False
self._use_fp32_param_shard()
if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage.
for p in self.params:
p._shard_size = p.data.size() # type: ignore
p._handle = self.ssd_buffer.insert(p.data) # type: ignore
free_storage_(p.data)
self.ssd_buffer.to_disk()
self.training_state = TrainingState.IDLE
def _reset_lazy_init(self) -> None:
......@@ -1044,6 +1136,9 @@ class FullyShardedDataParallel(nn.Module):
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
if not self.ssd_offload:
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard
......@@ -1083,6 +1178,11 @@ class FullyShardedDataParallel(nn.Module):
# pass. In this case, it's important to pre-allocate the CPU grad
# shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation.
if self.ssd_offload:
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
p._cpu_grad = torch.zeros_like(p.data, device="cpu")
else:
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
def _set_is_root(self) -> None:
......@@ -1171,6 +1271,9 @@ class FullyShardedDataParallel(nn.Module):
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
if self.ssd_offload:
self._move_params_to_memory()
self._lazy_init()
# Start of a forward pass.
......@@ -1228,8 +1331,15 @@ class FullyShardedDataParallel(nn.Module):
if self.clear_autocast_cache:
torch.clear_autocast_cache()
self._free_ssd_offload()
return outputs
@torch.no_grad()
def _free_ssd_offload(self) -> None:
if self.ssd_offload:
self.ssd_buffer.to_disk()
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward.
......@@ -1268,7 +1378,7 @@ class FullyShardedDataParallel(nn.Module):
# Note, both ``self._rebuild_full_params`` and ``self._use_full_params`` are
# idempotent. So in case they are called unnecessarily, they don't incur much
# overhead.
if self.reshard_after_forward:
if self.ssd_offload or self.reshard_after_forward:
self._rebuild_full_params()
else:
self._use_full_params()
......@@ -1593,6 +1703,7 @@ class FullyShardedDataParallel(nn.Module):
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
_finalize_parameters(m)
self._free_ssd_offload()
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
......@@ -1669,6 +1780,16 @@ class FullyShardedDataParallel(nn.Module):
# Trim any padding and reshape to match original size.
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
if self.ssd_offload:
self.ssd_buffer.from_disk(self.buffer_size)
# The params are on disk and need to be moved to the CPU.
for p, handle in zip(self.params, self.ssd_buffer.get_tensors()):
p._fp32_shard = handle.get_tensor().view(p._shard_size) # type: ignore
p.data = p._fp32_shard
self.has_full_params = False
# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not force_full_precision:
for p in self.params:
......
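Taken together, the FSDP changes above suggest the following usage pattern for an eval workload. This is a hedged sketch (not part of this commit) that assumes an initialized process group and an available GPU, matching the tests added below; the toy model and sizes are placeholders.

import torch
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (as in the spawned tests below).
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
fsdp = FSDP(model, ssd_offload=True)  # experimental kwarg; also forces move_params_to_cpu/move_grads_to_cpu
fsdp.eval()                           # SSD offload currently targets eval workloads only

out = fsdp(torch.rand(2, 16, device="cuda"))  # forward() pages params in from SSD and writes them back after

# Only the local (sharded) state dict is expected to work with ssd_offload.
sd = fsdp.local_state_dict()
fsdp.load_local_state_dict(sd)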
......@@ -52,3 +52,4 @@ tests/experimental/nn/test_auto_shard.py
tests/experimental/optim/test_dynamic_loss_scaler.py
tests/experimental/tooling/test_layer_memory_tracker.py
tests/experimental/nn/test_ssd_offload.py
tests/nn/data_parallel/test_fsdp_offload.py
......@@ -155,9 +155,6 @@ def test_ssd_buffer_null_buffer():
hdl_a = ssd_buf.insert(refa_tensor)
ssd_buf.to_disk()
with pytest.raises(AssertionError):
ssd_buf.to_disk()
with pytest.raises(AssertionError):
hdl_a = ssd_buf.insert(refa_tensor)
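The assertion removed above reflects the new StorageState bookkeeping: SsdBuffer.to_disk() is now idempotent (it early-returns when already ON_DISK), while insert() on a buffer that lives on disk still fails the ON_CPU assertion. A small hedged sketch of the revised expectation (path and sizes are placeholders):

import pytest
import torch
import fairscale.experimental.nn.ssd_offload as so

buf = so.SsdBuffer(128, "/tmp/ssd_buf_demo")
buf.insert(torch.rand(128))
buf.to_disk()
buf.to_disk()  # no longer raises; early return on StorageState.ON_DISK
with pytest.raises(AssertionError):
    buf.insert(torch.rand(128))  # buffer storage is on disk, so the ON_CPU assertion fires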
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import glob
import itertools
import os
import sys
import time
import unittest
from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, rmf, spawn_for_all_world_sizes
# Note: We need the nightly PyTorch version for SSD offload to work. Hence we check for the next PyTorch release.
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either a @classmethod or a @staticmethod
class DistributedTest(unittest.TestCase):
def setUp(self):
if torch_version() < (1, 6, 0):
raise unittest.SkipTest("Need pytorch version >= 1.6 due to lack of reduce_scatter")
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA not available, skipping test")
if sys.platform == "win32":
raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
if torch.cuda.device_count() < 2:
raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
@staticmethod
def _eval_with_config(model, autocast):
model.eval()
model_device = torch.device("cuda")
with torch.cuda.amp.autocast(enabled=autocast):
# Inputs are always on CUDA regardless of move_grads_cpu or model.device
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
return loss.detach()
@staticmethod
def _eval_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
model.eval()
# Inputs are always on CUDA regardless of move_grads_cpu or model.device
input = model.module.get_input(torch.device("cuda"))
for _ in range(num_steps):
with torch.cuda.amp.autocast(enabled=autocast):
output = model(*input)
@classmethod
def _test_identical_outputs_eval(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False
# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._eval_with_config(model, autocast)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()
# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if not config.get("ssd_offload", False):
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._eval_with_config(model, autocast)
try:
torch.testing.assert_allclose(ref_loss, shard_loss)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
if config.get("flatten_parameters", True):
metadata = model.local_metadata_dict()
assert isinstance(metadata, dict)
keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters", "nested_wrapping"]
CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
def rename_test(testcase_func, param_num, param):
return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)
class TestSsdMemory(DistributedTest):
def test_memory_benchmark(self):
test_fn = functools.partial(self._test_memory_benchmark, config={})
spawn_and_init(test_fn)
@classmethod
def _test_memory_benchmark(self, rank, group, config):
time_keeper = TimeKeeper()
SIZE = 8 * 8
time_keeper.print_time("START", 1.0)
a = torch.empty(1)
b = a.cuda()
# wait for cuda to fully load
time.sleep(1)
time_keeper.print_time("INIT_CUDA", 1.0)
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
time_keeper.print_time("CPU_MODEL", 1.0)
config["ssd_offload"] = True
model = FullyShardedDataParallel(model, **config)
time_keeper.print_time("FSDP_MODEL", 1.0)
self._eval_for_several_steps(model, 1, autocast=False)
time_keeper.print_time("EVAL")
fileList = glob.glob(os.getcwd() + "/*_rank*")
for file in fileList:
rmf(file)
class SimpleLinear(nn.Module):
def __init__(self, group, input_size, output_size, layers=1, **unused_kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
self.input_size = input_size
self.output_size = output_size
torch.manual_seed(0) # keep everything deterministic
seq_layers = []
for i in range(layers):
seq_layers.append(nn.Linear(input_size, output_size, bias=False))
self.module = nn.Sequential(*seq_layers)
self.bs = 2
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
src = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32)
tgt = torch.rand((self.bs, self.input_size), device=device, dtype=torch.float32)
return (src, tgt)
def forward(self, src_ids, tgt_ids):
param_devices = [p.device for p in self.module.parameters()]
return self.module(src_ids)
def get_loss(self, input, output):
_, tgt = input
return nn.functional.binary_cross_entropy_with_logits(output, tgt)
def run_backward(self, loss):
loss.backward()
KEYS = ["ssd_offload", "flatten_parameters", "mixed_precision", "move_params_to_cpu"]
CONFIG = [[dict(zip(KEYS, config))] for config in itertools.product([True, False], repeat=len(KEYS))]
class TimeKeeper:
def __init__(self):
self.start_time = time.time()
def print_time(self, s: str, wait_time: float = 1.0):
cur_time = time.time()
print(f"@time: {cur_time - self.start_time:0.2f} {s}")
time.sleep(wait_time)
class TestModuleProperties(DistributedTest):
@parameterized.expand(CONFIG, name_func=rename_test)
def test_named_parameters(self, config):
test_fn = functools.partial(self._test_named_params, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_named_params(self, rank, group, config):
# Get the named parameters before wrapping.
before_wrap_model = TransformerWithSharedParams(group)
before_wrap_params = before_wrap_model.named_parameters()
config["ssd_offload"] = True
model = FullyShardedDataParallel(before_wrap_model, **config)
if not config["ssd_offload"]:
model = model.cuda()
self._eval_with_config(model, autocast=config["mixed_precision"])
# Get the named parameters after wrapping to compare.
after_wrap_params = model.named_parameters()
if not config.get("flatten_parameters", False):
for before_nm, after_nm in zip(before_wrap_params, after_wrap_params):
assert before_nm[0] == after_nm[0]
else:
named_params_flat = [p for p in after_wrap_params][0][0]
assert "flat_param_0" in named_params_flat
after_wrap_params = model.named_parameters()
for before_nm, after_nm_original in zip(before_wrap_params, after_wrap_params):
assert before_nm[0] == after_nm_original[0]
torch.testing.assert_allclose(before_nm[1].shape, after_nm_original[1].shape)
class TestSsdLoading(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_eval(self, config):
test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@classmethod
def _test_ssd_offload_eval(self, rank, group, config):
model = TransformerWithSharedParams(group)
state_dict = model.state_dict()
nested_wrapping = config["nested_wrapping"]
del config["nested_wrapping"]
config["ssd_offload"] = True
if nested_wrapping:
model = FullyShardedDataParallel(NestedWrappedModule(group, wrap_everything=True, wrapper_config=config))
else:
model = FullyShardedDataParallel(model, **config)
if not config["ssd_offload"]:
model = model.cuda()
self._eval_with_config(model, autocast=config["mixed_precision"])
# With SSD offload only local_state_dict will work. We can support global
# state dict if we think it is necessary.
state_dict = model.local_state_dict()
model.load_local_state_dict(state_dict)
self._eval_with_config(model, config["mixed_precision"])
fileList = glob.glob(os.getcwd() + "/*_rank*")
for file in fileList:
rmf(file)
class TransformerWithSharedParams(nn.Module):
def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
torch.manual_seed(0) # keep everything deterministic
assert d_vocab >= 12 # we use torch.arange(12) as input
self.embed_tokens = nn.Embedding(d_vocab, d_model)
self.transformer = nn.Transformer(
d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
)
self.output_proj = nn.Linear(d_model, d_vocab)
# share the embedding and output projection weights
self.output_proj.weight = self.embed_tokens.weight
self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))
self.bs = 2
self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
src = torch.arange(12, device=device).view(6, self.bs) # T x B
tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs) # T x B
return (src, tgt)
def forward(self, src_ids, tgt_ids):
src = self.embed_tokens(src_ids)
src = src + self.vocab_bias + self.long_buffer.type_as(src)
tgt = self.embed_tokens(tgt_ids)
tgt = self.bn(tgt)
x = self.transformer(src, tgt)
return self.output_proj(x)
def get_loss(self, input, output):
_, tgt = input
return nn.functional.cross_entropy(output.view(-1, output.size(-1)), tgt.view(-1), reduction="sum")
def run_backward(self, loss):
loss.backward()
class NestedWrappedModule(nn.Module):
def __init__(self, group, wrapper_config, wrap_everything=False, checkpoint=False):
super().__init__()
self.rank = group.rank()
self.world_size = group.size()
self.wrapper_config = wrapper_config
def _maybe_wrap(layer):
if wrapper_config is not None:
return FullyShardedDataParallel(layer, group, **wrapper_config)
return layer
torch.manual_seed(0) # keep everything deterministic
self.module = nn.Sequential(
nn.Linear(8, 4),
_maybe_wrap(nn.Sequential(_maybe_wrap(nn.Linear(4, 16)), nn.Linear(16, 16),)),
_maybe_wrap(nn.Linear(16, 4)),
nn.Linear(4, 8),
)
# Wrapping all modules triggers a corner case where the root FSDP doesn't have any params.
# Test it with checkpoint_wrapper as well to validate final backward callback
# is queued correctly when root FSDP does not have any params and every layer is
# wrapped as FSDP(checkpoint(module)).
if wrap_everything:
if checkpoint:
self.module = nn.Sequential(
_maybe_wrap(checkpoint_wrapper(nn.Linear(8, 4))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(4, 16))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(16, 4))),
_maybe_wrap(checkpoint_wrapper(nn.Linear(4, 8))),
)
else:
self.module = nn.Sequential(
_maybe_wrap(nn.Linear(8, 4)),
_maybe_wrap(nn.Linear(4, 16)),
_maybe_wrap(nn.Linear(16, 4)),
_maybe_wrap(nn.Linear(4, 8)),
)
def get_input(self, device):
torch.manual_seed(1 + self.rank) # keep everything deterministic
return (torch.rand(4, 8, device=device),)
def forward(self, x):
return self.module(x)
def get_loss(self, input, output):
loss = output.sum()
return loss
def run_backward(self, loss):
loss.backward()
def spawn_and_init(fn, args=None, **spawn_kwargs):
if args is None:
args = ()
run_fn = functools.partial(init_and_run, fn, args)
spawn_for_all_world_sizes(run_fn, **spawn_kwargs)
def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
dist_init(rank, world_size, filename, filename_rpc)
group = torch.distributed.new_group()
fn(rank, group, *args)
if __name__ == "__main__":
unittest.main()
......@@ -223,17 +223,17 @@ class TestStateDictDeviceDtype(DistributedTest):
autocast = fsdp_model.mixed_precision or pure_fp16
self._train_for_several_steps(fsdp_model, 1, autocast)
sd = fsdp_model.state_dict()
state_dict = fsdp_model.state_dict()
sd_device = config.get("state_dict_device")
for k, v in sd.items():
if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
state_dict_device = config.get("state_dict_device")
for k, v in state_dict.items():
if config["cpu_offload"] or (state_dict_device is not None and state_dict_device.type == "cpu"):
assert v.device.type == "cpu", v.device.type
else:
assert v.device.type == "cuda", v.device.type
expected_dtype = torch.float16 if pure_fp16 else torch.float32
for k, v in sd.items():
for k, v in state_dict.items():
if not torch.is_floating_point(v):
continue
assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
......