Commit c5e471bc authored by Paul Johnson, committed by GitHub

Enable ssd_offload training and basic tests. (#887)

* Enabled ssd_offload training and testing via tests/nn/data_parallel/test_fsdp_offload.py.
* Removed unused classes: SsdBuffer, SsdTensorHandleView, SsdParameter, SsdTensor.
* Enhanced test coverage of test_ssd_offloading_train_flatten_params_wrapper.
* Addressed PR #887 review comments.
* Updated CHANGELOG.
parent 541bb8c9
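The sketch below (not part of the diff; names and sizes are illustrative) summarizes the core of the change: the removed SsdBuffer workflow is replaced by SsdFlatParameter, which flattens a list of parameters, can be flushed to and restored from disk, and hands back views with the original shapes.

import tempfile

import torch
import fairscale.experimental.nn.ssd_offload as so

a = torch.nn.Parameter(torch.rand(32, 4))
b = torch.nn.Parameter(torch.rand(128))

with tempfile.NamedTemporaryFile() as f:
    # One flat, file-backed parameter instead of a shared SsdBuffer.
    flat = so.SsdFlatParameter([a, b], f.name, requires_grad=False)

    # Views recover the original shapes without extra copies.
    views = list(flat.get_param_views())
    assert views[0].shape == a.shape and views[1].shape == b.shape

    # Flush to SSD (releases host memory), then read back on demand.
    flat.to_file()
    restored = flat.to_tensor()
    assert restored.numel() == a.numel() + b.numel()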
......@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Fixed a corner case of FSDP init order and losing one of the flags [#880]
- FSDP: Added basic training support for SSD offload; it currently supports only flattened parameters. Renamed OffloadConfig.ssd_filepath_dir to the more generic OffloadConfig.dir. SSD offload remains an experimental feature. [#887]
## [0.4.3] - 2021-11-18
......
......@@ -9,7 +9,7 @@ from enum import Enum, auto
from functools import reduce
import io
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type
import numpy as np
import torch
......@@ -131,7 +131,7 @@ class SsdTensorHandle(torch.Tensor):
return handle
@classmethod
def from_tensor(cls, tensor: torch.Tensor) -> SsdTensorHandle:
def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle:
"""Returns a new SsdTensorHandle from a tensor."""
handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad)
handle.tensor = tensor
......@@ -159,6 +159,13 @@ class SsdTensorHandle(torch.Tensor):
assert self._dtype == tensor.dtype
self.tensor = tensor
# If resizing a handle that is part of an SSD buffer, care must be taken that the new size
# doesn't conflict with adjacent handles!
def point_to_resized_tensor(self, tensor: torch.Tensor) -> None:
"""Points the handle to a tensor whose shape may differ from the handle's original shape."""
assert self._dtype == tensor.dtype
self._shape = tensor.shape
self.tensor = tensor
def to_tensor(self) -> torch.Tensor:
"""Returns the tensor represented by the SsdTensorHandle object.
......@@ -173,9 +180,11 @@ class SsdTensorHandle(torch.Tensor):
self.storage_state = StorageState.ON_CPU
return self.tensor
def to_file(self, release_tensor_after_write: bool = True) -> None:
def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
"""Saves the tensor to disk and releases memory if specified."""
assert self.tensor is not None
assert self.tensor is not None or permit_when_tensor_none
if self.tensor is not None:
write(self.tensor, self.filename, self.offset * self.tensor.element_size())
if release_tensor_after_write:
self.tensor = None
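A small sketch (illustrative; a temporary file stands in for the real parameter file) of the new permit_when_tensor_none flag: once a handle has been written and released, a second to_file() call becomes a no-op instead of tripping the assert, which is what FSDP's _free_ssd_offload relies on later in this PR.

import tempfile

import torch
import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    handle = so.SsdTensorHandle.from_tensor(torch.zeros(8))
    handle.set_file_params(f.name, 0)
    handle.to_file()                              # writes to disk, releases handle.tensor
    handle.to_file(permit_when_tensor_none=True)  # tensor already released: harmless no-op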
......@@ -229,92 +238,96 @@ class SsdTensorHandle(torch.Tensor):
return r
class SsdBuffer:
class SsdFlatParameter(torch.nn.Parameter, SsdTensorHandle):
"""A parameter that is initialized from a list of parameters and can be
turned into a list of views as needed.
"""
The SsdBuffer represents a single buffer containing a list of tensors. Each of the
tensors are represented by a `SsdTensorHandle`.
Args:
num_elems (int): Dictates the size of the 1-D tensor.
dtype (torch.dtype): Dtype of the buffer.
"""
def __init__(self, num_elems: int, filename: str, dtype: torch.dtype = torch.float32) -> None:
self.buffer: torch.Tensor = torch.empty((num_elems,), dtype=dtype)
self.filename = filename
self.offset = 0
self.tensors: Dict[int, SsdTensorHandle] = {}
self.storage_state = StorageState.ON_CPU
def allocate(self, num_elems: int) -> SsdTensorHandle:
"""Allocates a new tensor handle of size num_elems."""
assert num_elems > 0
assert self.storage_state == StorageState.ON_CPU, self.storage_state
assert self.can_alloc(num_elems)
tensor = self.buffer.narrow(0, self.offset, num_elems)
def __new__(
cls, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True
) -> "SsdFlatParameter":
"""Make an object using the parent's __new__ function."""
# An empty or non-list input doesn't make sense.
if not isinstance(params, (list, tuple)) or len(params) == 0:
raise ValueError("A non-empty list or tuple argument is needed")
# Normally, all items are Parameters. But during pickling, we will have a single
# Tensor as the input and later in __init__, the correct _param_numels and _param_shapes
# are set.
if not all(isinstance(p, (torch.nn.Parameter, torch.Tensor)) for p in params):
raise ValueError("List items need to be Parameter types")
# Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
# hierarchy flat (using a single tensor to replace a tree of tensors). Therefore,
# adding back nesting and hierarchy is counter-productive. If nesting is encountered
# in the future, the reasonable thing to do is likely for the top level SsdFlatParameter to
# absorb the nested one and keep the result flat, free from hierarchy.
if any(isinstance(p, SsdFlatParameter) for p in params):
raise ValueError("Nesting SsdFlatParameter is not supported")
dtype = params[0].dtype
size = sum(p.numel() for p in params)
r = SsdTensorHandle._make_wrapper_subclass(cls, (size,), dtype=dtype, requires_grad=requires_grad) # type: ignore
return r
tensor_offset = self.offset
handle = SsdTensorHandle.from_tensor(tensor)
self.tensors[tensor_offset] = handle
handle.set_file_params(self.filename, tensor_offset)
self.offset += num_elems
def __init__(self, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True):
"""Initialize the _param_numels and _param_shapes lists."""
self._param_numels = [p.numel() for p in params]
total_numels = sum(self._param_numels)
assert (
self.numel() <= total_numels
), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
self._param_shapes = [p.size() for p in params]
return handle
# These are set by FPW class below, not by this class itself.
self._param_infos: List[Tuple[str, torch.nn.Module, str]] = []
self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = []
def insert(self, tensor: torch.Tensor) -> SsdTensorHandle:
"""Insert a new tensor by allocating memory and creating a corresponding handle."""
assert self.storage_state == StorageState.ON_CPU, self.storage_state
# For the non sharded case, the tensor will not be flattened
tensor = tensor.reshape(-1)
assert self.buffer.dtype == tensor.dtype
handle = self.allocate(tensor.numel())
handle.get_tensor().copy_(tensor)
return handle
super(SsdFlatParameter, self).__init__(shape=(total_numels,), dtype=params[0].dtype, requires_grad=requires_grad) # type: ignore
def can_alloc(self, num_elems: int) -> bool:
"""Verify that you can allocate a tensor within the bounds
of the larger SsdBuffer memory buffer."""
assert self.storage_state == StorageState.ON_CPU, self.storage_state
return (self.offset + num_elems) <= self.buffer.numel()
def get_tensors(self) -> List[SsdTensorHandle]:
"""Returns the list of tensor handles in SsdBuffer."""
return [t for t in self.tensors.values()]
def to_disk(self) -> None:
"""Writes all tensors backed by handles to disk."""
if self.storage_state == StorageState.ON_DISK:
return
assert self.storage_state == StorageState.ON_CPU, self.storage_state
# We use `narrow` so that we write valid tensors that have been allocated
# as opposed to the entire SSD buffer.
valid_data = self.buffer.narrow(0, 0, self.offset)
write(valid_data, self.filename)
# Remove all Tensor references
for offset, t in self.tensors.items():
t.point_to_file(self.filename, offset)
# TODO(anj-s): Setting this to None does not result in GC picking
# this reference up.
self.buffer = torch.empty((1))
self.storage_state = StorageState.ON_DISK
def from_disk(self, num_elems: int, dtype: torch.dtype = torch.float32) -> None:
"""Reads all tensors backed by handles into memory."""
if self.storage_state == StorageState.ON_CPU:
return
assert self.storage_state == StorageState.ON_DISK, self.storage_state
if num_elems < self.offset:
raise RuntimeError(
f"Attempted to load from file ssdbuffer of size: {self.offset} into a buffer that is of size: {num_elems}"
tensor = torch.cat(
[p.detach().reshape(-1) if isinstance(p, torch.nn.Parameter) else p.reshape(-1) for p in params], 0
)
self.buffer = torch.empty((num_elems,), dtype=dtype)
valid_data = self.buffer.narrow(0, 0, self.offset)
read(valid_data, self.filename)
for offset, t in self.tensors.items():
t.point_to_tensor(self.buffer.narrow(0, t.offset, t._numel))
tensor.requires_grad = requires_grad
self.set_file_params(filename, 0)
self.point_to_tensor(tensor)
self.storage_state = StorageState.ON_CPU
def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]:
"""Return a generator of views that map to the original parameters."""
# Note: self.data could be sharded, so its numel can be less than or equal to the sum.
"""
assert self.data.numel() <= sum(
self._param_numels
), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
"""
if external_data is not None:
if external_data.numel() != sum(self._param_numels):
raise ValueError(
f"Incorrect numel of supplied data: got {external_data.numel()} but expected {sum(self._param_numels)}"
)
return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes))
else:
return (t.view(s) for (t, s) in zip(self.split(self._param_numels), self._param_shapes))
def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter."""
names = [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]
return names, self._param_shapes, self._param_numels
def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
"""Use by pickle to set the internal states."""
(self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
assert self.numel() <= sum(
self._param_numels
), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
"""Support pickling between ranks."""
return (
SsdFlatParameter, # Callable
# Args to the callable above
([self.data], self.filename, self.requires_grad),
# Args to __setstate__
(self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos),
)
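A short sketch (illustrative) of the external_data path of get_param_views: FSDP can pass an externally materialized full tensor (e.g. after an all-gather) and get back views shaped like the original parameters, while the flat parameter itself stays on disk.

import tempfile

import torch
import fairscale.experimental.nn.ssd_offload as so

w = torch.nn.Parameter(torch.rand(4, 4))
b = torch.nn.Parameter(torch.rand(4))

with tempfile.NamedTemporaryFile() as f:
    flat = so.SsdFlatParameter([w, b], f.name, requires_grad=False)
    flat.to_file()  # the flat parameter now lives on SSD

    # Stand-in for an all-gathered full tensor of the right total size.
    full = torch.rand(w.numel() + b.numel())
    views = list(flat.get_param_views(full))
    assert views[0].shape == w.shape and views[1].shape == b.shape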
......@@ -65,6 +65,7 @@ else:
try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
import_ssd_offload = True
except ImportError:
......@@ -109,9 +110,9 @@ class OffloadConfig:
"""Class for specifying all arguments related to offloading parameters."""
# Offload type: currently only "ssd_offload" is supported.
offload_type: str = None
offload_type: Optional[str] = None
# Path to the directory for storing parameters offloaded to disk.
ssd_filepath_dir: str = None
dir: Optional[str] = None
class FullyShardedDataParallel(nn.Module):
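A hedged sketch of the renamed configuration (the OffloadConfig import path and process-group setup are assumptions): SSD offload is requested via offload_type="ssd_offload" and the new dir field, and requires flatten_parameters=True, as enforced further down in this diff.

import tempfile

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig

# Assumes torch.distributed has already been initialized with a default group.
with tempfile.TemporaryDirectory() as tmpdir:
    fsdp = FullyShardedDataParallel(
        nn.Linear(16, 16),
        flatten_parameters=True,  # required when offload_type="ssd_offload"
        offload_config=OffloadConfig(offload_type="ssd_offload", dir=tmpdir),
    )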
......@@ -300,7 +301,7 @@ class FullyShardedDataParallel(nn.Module):
force_input_to_fp32: bool = False,
verbose: bool = False,
cpu_offload: bool = False,
offload_config: OffloadConfig = None,
offload_config: Optional[OffloadConfig] = None,
):
init_start = time.time()
super().__init__()
......@@ -335,6 +336,9 @@ class FullyShardedDataParallel(nn.Module):
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.ssd_offload and not self.flatten_parameters:
raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")
# skip validation if the process group was created above
if process_group:
validate_process_group(self.compute_device, self.process_group)
......@@ -358,13 +362,11 @@ class FullyShardedDataParallel(nn.Module):
# TODO(anj): Should we conditionally do this only if we have params?
# TODO(anj): Figure out if we can allocate the buffer during sharding.
self.buffer_size = sum(p.numel() for p in params)
self.ssd_directory = tempfile.gettempdir()
if self.ssd_offload:
assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
self.ssd_buffer_filepath_dir = (
offload_config.ssd_filepath_dir if offload_config.ssd_filepath_dir else tempfile.gettempdir()
)
self.ssd_buffer_filename = tempfile.mkstemp(dir=self.ssd_buffer_filepath_dir)
self.ssd_buffer = ssd_offload.SsdBuffer(self.buffer_size, self.ssd_buffer_filename[1])
if offload_config and offload_config.dir:
self.ssd_directory = offload_config.dir
self.move_grads_to_cpu = True
self.move_params_to_cpu = True
......@@ -379,7 +381,9 @@ class FullyShardedDataParallel(nn.Module):
param_name_groups = [param_names]
del param_names
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
)
del module # free original module in case it helps garbage collection
# Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
......@@ -675,15 +679,7 @@ class FullyShardedDataParallel(nn.Module):
p._orig_size = p.data.size()
if not p._is_sharded:
if self.ssd_offload:
# Insert tensor into the SSD buffer and free parameter storage.
p._is_sharded = False
self.numel_padded_per_param.append(0)
p._shard_size = p.data.size() # type: ignore
p._handle = self.ssd_buffer.insert(p.data) # type: ignore
free_storage_(p.data)
continue
else:
if not self.ssd_offload:
p._is_sharded = False
self.numel_padded_per_param.append(0)
continue
......@@ -691,14 +687,11 @@ class FullyShardedDataParallel(nn.Module):
# Replace p.data with the relevant shard.
if self.ssd_offload:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
p._shard_size = p.data.size() # type: ignore
# Insert tensor into the SSD buffer and free parameter storage.
p._handle = self.ssd_buffer.insert(p.data) # type: ignore
del orig_data
assert isinstance(p, SsdFlatParameter)
sharded_tensor, num_padded = self._get_shard(p.data)
p.point_to_resized_tensor(sharded_tensor)
self.numel_padded_per_param.append(num_padded)
free_storage_(p.data)
p.to_file()
else:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
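An illustrative sketch (sizes and the shard itself are made up; the real code obtains the shard from self._get_shard) of what the ssd_offload branch above does to an SsdFlatParameter: point the handle at the smaller local shard, then flush that shard to disk.

import tempfile

import torch
import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    p = so.SsdFlatParameter([torch.nn.Parameter(torch.rand(8))], f.name, requires_grad=False)

    shard = torch.rand(4)               # stand-in for this rank's shard from _get_shard
    p.point_to_resized_tensor(shard)    # handle now reports the shard's (smaller) shape
    p.to_file()                         # write the shard to SSD, release host memory
    assert torch.equal(p.to_tensor(), shard)  # reading back yields the shard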
......@@ -707,10 +700,6 @@ class FullyShardedDataParallel(nn.Module):
assert len(self.numel_padded_per_param) == len(self.params)
# Move SSD buffer to disk.
if self.ssd_offload:
self.ssd_buffer.to_disk()
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
......@@ -791,11 +780,6 @@ class FullyShardedDataParallel(nn.Module):
"""Returns an iterator over the module parameters, yielding all the parameters
part of the model.
"""
# TODO(anj): Use `copy_into_tensor` in order to provide a copy of the
# parameters and not the actual parameters. Ideally we don't users to operate on
# actual params.
if self.ssd_offload:
self.ssd_buffer.from_disk(self.buffer_size)
return super().parameters(recurse=recurse)
......@@ -810,12 +794,6 @@ class FullyShardedDataParallel(nn.Module):
If you want the full param to be returned, you should call this function
under a `summon_full_params` context when using flattened or original params.
"""
# TODO(anj): Use `copy_into_tensor` in order to provide a copy of the
# parameters and not the actual parameters. Ideally we don't users to operate on
# actual params.
if self.ssd_offload:
self.ssd_buffer.from_disk(self.buffer_size)
named_param = super().named_parameters(*args, **kwargs)
for name, param in named_param:
if (
......@@ -923,10 +901,9 @@ class FullyShardedDataParallel(nn.Module):
def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU."""
self.ssd_buffer.from_disk(self.buffer_size)
for p, handle in zip(self.params, self.ssd_buffer.get_tensors()):
p.data = handle.get_tensor().view(p._shard_size) # type: ignore
for p in self.params:
assert isinstance(p, SsdFlatParameter)
p.to_tensor()
def _load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
......@@ -1072,14 +1049,13 @@ class FullyShardedDataParallel(nn.Module):
if safe_to_free:
free_storage_(full_tensor)
self.has_full_params = False
self._use_fp32_param_shard()
if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage.
for p in self.params:
p._shard_size = p.data.size() # type: ignore
p._handle = self.ssd_buffer.insert(p.data) # type: ignore
free_storage_(p.data)
self.ssd_buffer.to_disk()
assert isinstance(p, SsdFlatParameter)
p.to_file()
else:
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
def _reset_lazy_init(self) -> None:
......@@ -1153,6 +1129,11 @@ class FullyShardedDataParallel(nn.Module):
return
# A single shard of the parameters in full precision.
# TODO(another-pjohnson) - I believe this will cause memory leakage with ssd offload:
# p.data returns a pointer to a handle, and that handle has its
# ref count incremented by p._fp32_shard. So this tensor will
# never be freed even if we do p.to_disk(). Investigate after
# PR #887 is merged.
p._fp32_shard = p.data
if self.mixed_precision:
......@@ -1160,12 +1141,12 @@ class FullyShardedDataParallel(nn.Module):
if self.move_params_to_cpu:
assert p._fp32_shard.device == torch.device("cpu")
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
if not self.ssd_offload:
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
if not self.ssd_offload:
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
p._fp32_shard = p._fp32_shard.pin_memory()
p.data = p._fp32_shard
......@@ -1206,9 +1187,12 @@ class FullyShardedDataParallel(nn.Module):
# shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation.
if self.ssd_offload:
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
p._cpu_grad = torch.zeros_like(p.data, device="cpu")
assert isinstance(p, SsdFlatParameter)
# Gradients also need to be offloaded to SSD otherwise it can result in
# OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file()
else:
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
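A minimal sketch of the gradient handle created above: CPU gradients get their own SsdTensorHandle backed by a sibling "<param file>_grad" file, so gradients can also spill to SSD when the model is larger than host memory (the file name below is hypothetical).

import os
import tempfile

import torch
import fairscale.experimental.nn.ssd_offload as so

with tempfile.TemporaryDirectory() as d:
    grad_file = os.path.join(d, "flat_param_0_grad")  # hypothetical file name
    cpu_grad = so.SsdTensorHandle.from_tensor(torch.zeros(16))
    cpu_grad.set_file_params(grad_file, 0)
    cpu_grad.to_file()                    # zeros now live on disk, host memory released
    assert torch.count_nonzero(cpu_grad.to_tensor()) == 0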
......@@ -1298,9 +1282,6 @@ class FullyShardedDataParallel(nn.Module):
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
if self.ssd_offload:
self._move_params_to_memory()
self._lazy_init()
# Start of a forward pass.
......@@ -1365,7 +1346,9 @@ class FullyShardedDataParallel(nn.Module):
@torch.no_grad()
def _free_ssd_offload(self) -> None:
if self.ssd_offload:
self.ssd_buffer.to_disk()
for p in self.params:
assert isinstance(p, SsdFlatParameter)
p.to_file(permit_when_tensor_none=True)
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
......@@ -1405,7 +1388,7 @@ class FullyShardedDataParallel(nn.Module):
# Note, both ``self._rebuild_full_params`` and ``self._use_full_params`` are
# idempotent. So in case they are called unnecessarily, they don't incur much
# overhead.
if self.ssd_offload or self.reshard_after_forward:
if self.reshard_after_forward:
self._rebuild_full_params()
else:
self._use_full_params()
......@@ -1730,7 +1713,7 @@ class FullyShardedDataParallel(nn.Module):
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
_finalize_parameters(m)
self._free_ssd_offload()
m._free_ssd_offload()
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
......@@ -1808,12 +1791,9 @@ class FullyShardedDataParallel(nn.Module):
p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
if self.ssd_offload:
self.ssd_buffer.from_disk(self.buffer_size)
# The params are on disk and need to be moved to the CPU.
for p, handle in zip(self.params, self.ssd_buffer.get_tensors()):
p._fp32_shard = handle.get_tensor().view(p._shard_size) # type: ignore
p.data = p._fp32_shard
for p in self.params:
assert isinstance(p, SsdFlatParameter)
p.to_tensor()
self.has_full_params = False
......
......@@ -8,6 +8,7 @@
from contextlib import contextmanager
from itertools import chain
import tempfile
import typing
from typing import (
TYPE_CHECKING,
......@@ -30,6 +31,7 @@ import torch
from torch import Tensor
import torch.nn as nn
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
from fairscale.utils.state_dict import replace_by_prefix_
if TYPE_CHECKING:
......@@ -114,6 +116,7 @@ class FlatParameter(nn.Parameter):
# Static types.
FlatTypes = Union[FlatParameter, SsdFlatParameter]
ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]]
......@@ -147,7 +150,14 @@ class FlattenParamsWrapper(nn.Module):
prefix will be added to those names.
"""
def __init__(self, module: nn.Module, param_list: ParamGroups = None, flat_param_names: Optional[List[str]] = None):
def __init__(
self,
module: nn.Module,
param_list: ParamGroups = None,
flat_param_names: Optional[List[str]] = None,
ssd_offload: bool = False,
ssd_directory: str = "",
):
super().__init__()
self._fpw_module = module
self.is_flattened = False
......@@ -195,7 +205,7 @@ class FlattenParamsWrapper(nn.Module):
# support.
raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}")
self.flat_params: List[FlatParameter] = []
self.flat_params: List[FlatTypes] = []
# Prepare flat param names.
if flat_param_names is None:
......@@ -205,10 +215,16 @@ class FlattenParamsWrapper(nn.Module):
if len(flat_param_names) != len(set(flat_param_names)):
raise ValueError("Each flat param must be given a unique name")
self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names]
flat_param: Optional[FlatTypes] = None
# Init all flat_params.
for new_p_set in self._param_sets:
params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set)
if ssd_offload:
assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter(params=params, filename=fname, requires_grad=params[0].requires_grad)
else:
flat_param = FlatParameter(params, params[0].requires_grad)
flat_param._param_infos = param_infos
flat_param._shared_param_infos = shared_param_infos
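A hedged sketch of the new constructor arguments (the import path is an assumption): with ssd_offload=True each flat parameter is created as an SsdFlatParameter backed by a temp file inside ssd_directory.

import tempfile

import torch.nn as nn
from fairscale.nn.misc import FlattenParamsWrapper

module = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
with tempfile.TemporaryDirectory() as tmpdir:
    fpw = FlattenParamsWrapper(module, ssd_offload=True, ssd_directory=tmpdir)
    assert len(fpw.flat_params) == 1  # one flat, SSD-backed parameter for the whole module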
......@@ -288,7 +304,7 @@ class FlattenParamsWrapper(nn.Module):
def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]:
return chain(*[p._shared_param_infos for p in self.flat_params])
def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
def _flatten_params(self, flat_params: List[FlatTypes]) -> None:
"""Flatten the managed parameters and replaced the original
attributes with views to the flat params.
"""
......
from typing import Any, BinaryIO, Union
import os
import pickle
from typing import Any, BinaryIO, Callable, IO, Union
def save(obj, f: Union[str, BinaryIO]) -> None: ...
DEFAULT_PROTOCOL: int = 2
def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
pickle_module: Any=pickle, pickle_protocol: int=DEFAULT_PROTOCOL, _use_new_zipfile_serialization: bool=True) -> None: ...
def load(f: Union[str, BinaryIO], map_location) -> Any: ...
......@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
"""
Testing SsdBuffer and SsdTensorHandle modules.
Testing SsdFlatParameter and SsdTensorHandle modules.
"""
import tempfile
......@@ -38,6 +38,8 @@ def test_write_read():
def test_ssd_handle_dispatch_fwd():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((128))
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
......@@ -54,6 +56,8 @@ def test_ssd_handle_dispatch_fwd():
def test_ssd_handle_dispatch_bwd():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
......@@ -68,98 +72,89 @@ def test_ssd_handle_dispatch_bwd():
y1.sum().backward()
y2.sum().backward()
# TODO: PJ/ASenable assert once Tensor._make_subclass can properly define the tensor's shape
# assert torch.equal(ssd_handle.grad, orig_copy.grad)
assert torch.equal(ssd_handle.grad, orig_copy.grad)
def test_ssd_buffer_basic():
def test_ssd_handle_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_tensor = torch.rand((128), dtype=torch.float32)
refb_tensor = torch.rand((128), dtype=torch.float32)
refc_tensor = torch.rand((128), dtype=torch.float32)
ssd_buf = so.SsdBuffer(1024, f.name)
hdl_a = ssd_buf.insert(refa_tensor)
hdl_b = ssd_buf.insert(refb_tensor)
hdl_c = ssd_buf.insert(refc_tensor)
assert hdl_a.is_available()
assert hdl_b.is_available()
assert hdl_c.is_available()
assert torch.equal(refa_tensor, hdl_a.get_tensor())
assert torch.equal(refb_tensor, hdl_b.get_tensor())
assert torch.equal(refc_tensor, hdl_c.get_tensor())
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
tensors = ssd_buf.get_tensors()
assert hdl_a is tensors[0]
assert hdl_b is tensors[1]
assert hdl_c is tensors[2]
with torch.no_grad():
orig_copy = torch.empty_like(orig_tensor)
orig_copy.copy_(orig_tensor)
orig_copy.requires_grad = True
# test read_into_tensor when handle.is_available()
b_tensor_copy1 = torch.empty_like(refb_tensor)
hdl_b.copy_into_tensor(b_tensor_copy1)
assert torch.equal(refb_tensor, b_tensor_copy1)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
# remove references so memory will be cleaned up
buffer = None
assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1)
optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1)
ssd_buf.to_disk()
y1 = ssd_handle + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
optimizer_ssd.step()
assert hdl_a.filename == f.name
assert hdl_b.filename == f.name
assert hdl_c.filename == f.name
y2 = orig_copy + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
assert hdl_a.offset == 0
assert hdl_b.offset == 128
assert hdl_c.offset == 256
# make sure we are using the file version not the cached tensor
ssd_handle.point_to_file(f.name, 0)
assert torch.equal(ssd_handle.to_tensor(), orig_copy)
assert not hdl_a.is_available()
assert not hdl_b.is_available()
assert not hdl_c.is_available()
# test read_into_tensor when !handle.is_available()
b_tensor_copy2 = torch.empty_like(refb_tensor)
hdl_b.copy_into_tensor(b_tensor_copy2)
assert torch.equal(refb_tensor, b_tensor_copy2)
def test_ssd_flat_param_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4))
ssd_buf.from_disk(384)
with torch.no_grad():
orig_copy = torch.empty_like(orig_tensor)
orig_copy.copy_(orig_tensor)
param = torch.nn.Parameter(orig_copy)
assert hdl_a.is_available()
assert hdl_b.is_available()
assert hdl_c.is_available()
ssd_flat_param = so.SsdFlatParameter([param], f.name, True)
assert torch.equal(refa_tensor, hdl_a.get_tensor())
assert torch.equal(refb_tensor, hdl_b.get_tensor())
assert torch.equal(refc_tensor, hdl_c.get_tensor())
assert torch.equal(list(ssd_flat_param.get_param_views())[0], orig_tensor)
optimizer_ssd = torch.optim.SGD([ssd_flat_param], lr=0.1)
optimizer_orig = torch.optim.SGD([param], lr=0.1)
y1 = ssd_flat_param + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
optimizer_ssd.step()
def test_ssd_buffer_too_small_from_disk():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_tensor = torch.rand((128), dtype=torch.float32)
ssd_buf = so.SsdBuffer(128, f.name)
hdl_a = ssd_buf.insert(refa_tensor)
ssd_buf.to_disk()
y2 = param + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
with pytest.raises(RuntimeError):
ssd_buf.from_disk(127)
# make sure we are using the file version not the cached tensor
ssd_flat_param.point_to_file(f.name, 0)
assert torch.equal(list(ssd_flat_param.get_param_views())[0], param)
def test_ssd_buffer_null_buffer():
def test_ssd_flat_parameter_basic():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_tensor = torch.rand((128), dtype=torch.float32)
ssd_buf = so.SsdBuffer(128, f.name)
hdl_a = ssd_buf.insert(refa_tensor)
ssd_buf.to_disk()
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter([refa_param, refb_param, refc_param], f.name, False)
with pytest.raises(AssertionError):
hdl_a = ssd_buf.insert(refa_tensor)
param_views = list(ssd_flat_param.get_param_views())
with pytest.raises(AssertionError):
ssd_buf.can_alloc(128)
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
with pytest.raises(AssertionError):
hdl = ssd_buf.allocate(128)
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
......@@ -102,6 +102,8 @@ class DistributedTest(unittest.TestCase):
# Confirm we get the same behavior using FullyShardedDataParallel.
if config.get("ssd_offload", False):
config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
# SSD offload only supports flatten_parameters at the moment.
config["flatten_parameters"] = True
del config["ssd_offload"]
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
......@@ -121,7 +123,7 @@ class DistributedTest(unittest.TestCase):
assert isinstance(metadata, dict)
keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters", "nested_wrapping"]
keys = ["reshard_after_forward", "mixed_precision", "nested_wrapping"]
CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
......@@ -152,7 +154,7 @@ class TestSsdMemory(DistributedTest):
time_keeper.print_time("CPU_MODEL", 1.0)
with tempfile.TemporaryDirectory() as current_tempdir:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
model = FullyShardedDataParallel(model, **config)
time_keeper.print_time("FSDP_MODEL", 1.0)
......@@ -223,7 +225,9 @@ class TestModuleProperties(DistributedTest):
with tempfile.TemporaryDirectory() as current_tempdir:
if config["ssd_offload"]:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
# SSD offload only supports flatten_parameters at the moment.
config["flatten_parameters"] = True
del config["ssd_offload"]
model = FullyShardedDataParallel(before_wrap_model, **config)
......@@ -260,6 +264,47 @@ class TestSsdLoading(DistributedTest):
def test_transformer_parameterized(self, config):
spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_train_flatten_params_wrapper(self, config):
test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
spawn_and_init(test_fn)
@classmethod
def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config):
SIZE = 16 * 16
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
with tempfile.TemporaryDirectory() as current_tempdir:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
config["flatten_parameters"] = True
nested_wrapping = config["nested_wrapping"]
del config["nested_wrapping"]
if nested_wrapping:
model = FullyShardedDataParallel(
NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
)
else:
model = FullyShardedDataParallel(model, **config)
model_device = torch.device("cuda")
model.train()
optim = torch.optim.SGD(model.parameters(), lr=4, momentum=0.9)
# Inputs are always on CUDA regardless of move_grads_cpu or model.device
with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
for i in range(10):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
@classmethod
def _test_ssd_offload_eval(self, rank, group, config):
model = TransformerWithSharedParams(group)
......@@ -267,9 +312,10 @@ class TestSsdLoading(DistributedTest):
nested_wrapping = config["nested_wrapping"]
del config["nested_wrapping"]
config["flatten_parameters"] = True
with tempfile.TemporaryDirectory() as current_tempdir:
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
if nested_wrapping:
model = FullyShardedDataParallel(
NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
......