Unverified Commit c5e471bc authored by Paul Johnson, committed by GitHub

Enable ssd_offload training and basic tests (#887)

* Enable ssd_offload training and testing via tests/nn/data_parallel/test_fsdp_offload.py.
* Remove unused classes: SsdBuffer, SsdTensorHandleView, SsdParameter, SsdTensor.
* Enhance test coverage of test_ssd_offloading_train_flatten_params_wrapper.
* Address review comments on PR #887.
* Update Changelog.
parent 541bb8c9
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Changed
 - Fixed a corner case of FSDP init order and losing one of the flags [#880]
+- FSDP: Added basic training support for SSD Offload; it currently only supports flattened parameters. Renamed OffloadConfig.ssd_filepath_dir to the more generic OffloadConfig.dir. SSD Offload remains an experimental feature. [#887]

 ## [0.4.3] - 2021-11-18
...
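For illustration, a minimal usage sketch of the reworked config (not part of this commit). The import paths and the toy module are assumptions, and a distributed process group must already be initialized before constructing FSDP:

```python
import tempfile

import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.nn.data_parallel.fully_sharded_data_parallel import OffloadConfig

# Toy module purely for illustration.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

with tempfile.TemporaryDirectory() as tmpdir:
    # SSD offload requires flattened parameters; parameter data is staged in
    # files created under OffloadConfig.dir (formerly ssd_filepath_dir).
    fsdp_model = FullyShardedDataParallel(
        model,
        flatten_parameters=True,
        offload_config=OffloadConfig(offload_type="ssd_offload", dir=tmpdir),
    )
```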
@@ -9,7 +9,7 @@ from enum import Enum, auto
 from functools import reduce
 import io
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type

 import numpy as np
 import torch
@@ -131,7 +131,7 @@ class SsdTensorHandle(torch.Tensor):
         return handle

     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor) -> SsdTensorHandle:
+    def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle:
         """Returns a new SsdTensorHandle from a tensor."""
         handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad)
         handle.tensor = tensor
@@ -159,6 +159,13 @@ class SsdTensorHandle(torch.Tensor):
         assert self._dtype == tensor.dtype
         self.tensor = tensor

+    # if resizing a handle that is part of an ssd buffer, care must be taken that the new size
+    # doesn't conflict with adjacent handles!
+    def point_to_resized_tensor(self, tensor: torch.Tensor) -> None:
+        assert self._dtype == tensor.dtype
+        self._shape = tensor.shape
+        self.tensor = tensor
+
     def to_tensor(self) -> torch.Tensor:
         """Returns the tensor represented by the SsdTensorHandle object.
@@ -173,9 +180,11 @@ class SsdTensorHandle(torch.Tensor):
         self.storage_state = StorageState.ON_CPU
         return self.tensor

-    def to_file(self, release_tensor_after_write: bool = True) -> None:
+    def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
         """Saves the tensor to disk and releases memory if specified."""
-        assert self.tensor is not None
-        write(self.tensor, self.filename, self.offset * self.tensor.element_size())
-        if release_tensor_after_write:
-            self.tensor = None
+        assert self.tensor is not None or permit_when_tensor_none
+
+        if self.tensor is not None:
+            write(self.tensor, self.filename, self.offset * self.tensor.element_size())
+            if release_tensor_after_write:
+                self.tensor = None
@@ -229,92 +238,96 @@ class SsdTensorHandle(torch.Tensor):
         return r


-class SsdBuffer:
+class SsdFlatParameter(torch.nn.Parameter, SsdTensorHandle):
+    """A parameter that is initialized from a list of parameters and can be
+    turned into a list of views as needed.
     """
-    The SsdBuffer represents a single buffer containing a list of tensors. Each of the
-    tensors are represented by a `SsdTensorHandle`.
-
-    Args:
-        num_elems (int): Dictates the size of the 1-D tensor.
-        dtype (torch.dtype): Dtype of the buffer.
-    """
-
-    def __init__(self, num_elems: int, filename: str, dtype: torch.dtype = torch.float32) -> None:
-        self.buffer: torch.Tensor = torch.empty((num_elems,), dtype=dtype)
-        self.filename = filename
-        self.offset = 0
-        self.tensors: Dict[int, SsdTensorHandle] = {}
-        self.storage_state = StorageState.ON_CPU
-
-    def allocate(self, num_elems: int) -> SsdTensorHandle:
-        """Allocates a new tensor handle of size num_elems."""
-        assert num_elems > 0
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        assert self.can_alloc(num_elems)
-        tensor = self.buffer.narrow(0, self.offset, num_elems)
-
-        tensor_offset = self.offset
-        handle = SsdTensorHandle.from_tensor(tensor)
-        self.tensors[tensor_offset] = handle
-        handle.set_file_params(self.filename, tensor_offset)
-        self.offset += num_elems
-
-        return handle
-
-    def insert(self, tensor: torch.Tensor) -> SsdTensorHandle:
-        """Insert a new tensor by allocating memory and creating a corresponding handle."""
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        # For the non sharded case, the tensor will not be flattened
-        tensor = tensor.reshape(-1)
-        assert self.buffer.dtype == tensor.dtype
-        handle = self.allocate(tensor.numel())
-        handle.get_tensor().copy_(tensor)
-        return handle
-
-    def can_alloc(self, num_elems: int) -> bool:
-        """Verify that you can allocate a tensor within the bounds
-        of the larger SsdBuffer memory buffer."""
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        return (self.offset + num_elems) <= self.buffer.numel()
-
-    def get_tensors(self) -> List[SsdTensorHandle]:
-        """Returns the list of tensor handles in SsdBuffer."""
-        return [t for t in self.tensors.values()]
-
-    def to_disk(self) -> None:
-        """Writes all tensors backed by handles to disk."""
-        if self.storage_state == StorageState.ON_DISK:
-            return
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        # We use `narrow` so that we write valid tensors that have been allocated
-        # as opposed to the entire SSD buffer.
-        valid_data = self.buffer.narrow(0, 0, self.offset)
-        write(valid_data, self.filename)
-
-        # Remove all Tensor references
-        for offset, t in self.tensors.items():
-            t.point_to_file(self.filename, offset)
-
-        # TODO(anj-s): Setting this to None does not result in GC picking
-        # this reference up.
-        self.buffer = torch.empty((1))
-        self.storage_state = StorageState.ON_DISK
-
-    def from_disk(self, num_elems: int, dtype: torch.dtype = torch.float32) -> None:
-        """Reads all tensors backed by handles into memory."""
-        if self.storage_state == StorageState.ON_CPU:
-            return
-        assert self.storage_state == StorageState.ON_DISK, self.storage_state
-        if num_elems < self.offset:
-            raise RuntimeError(
-                f"Attempted to load from file ssdbuffer of size: {self.offset} into a buffer that is of size: {num_elems}"
-            )
-        self.buffer = torch.empty((num_elems,), dtype=dtype)
-        valid_data = self.buffer.narrow(0, 0, self.offset)
-        read(valid_data, self.filename)
-
-        for offset, t in self.tensors.items():
-            t.point_to_tensor(self.buffer.narrow(0, t.offset, t._numel))
-
-        self.storage_state = StorageState.ON_CPU
+    def __new__(
+        cls, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True
+    ) -> "SsdFlatParameter":
+        """Make an object using the parent's __new__ function."""
+        # An empty or non-list input doesn't make sense.
+        if not isinstance(params, (list, tuple)) or len(params) == 0:
+            raise ValueError("A non-empty list or tuple argument is needed")
+
+        # Normally, all items are Parameters. But during pickling, we will have a single
+        # Tensor as the input and later in __init__, the correct _param_numels and _param_shapes
+        # are set.
+        if not all(isinstance(p, (torch.nn.Parameter, torch.Tensor)) for p in params):
+            raise ValueError("List items need to be Parameter types")
+
+        # Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
+        # hierarchy flat (using a single tensor to replace a tree of tensors). Therefore,
+        # adding back nesting and hierarchy is counter-productive. If nesting is encountered
+        # in the future, the reasonable thing to do is likely for the top level SsdFlatParameter to
+        # absorb the nested one and keep the result flat, free from hierarchy.
+        if any(isinstance(p, SsdFlatParameter) for p in params):
+            raise ValueError("Nesting SsdFlatParameter is not supported")
+
+        dtype = params[0].dtype
+        size = sum(p.numel() for p in params)
+        r = SsdTensorHandle._make_wrapper_subclass(cls, (size,), dtype=dtype, requires_grad=requires_grad)  # type: ignore
+        return r
+
+    def __init__(self, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True):
+        """Initialize the _param_numels and _param_shapes lists."""
+        self._param_numels = [p.numel() for p in params]
+        total_numels = sum(self._param_numels)
+        assert (
+            self.numel() <= total_numels
+        ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
+        self._param_shapes = [p.size() for p in params]
+
+        # These are set by FPW class below, not by this class itself.
+        self._param_infos: List[Tuple[str, torch.nn.Module, str]] = []
+        self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = []
+
+        super(SsdFlatParameter, self).__init__(shape=(total_numels,), dtype=params[0].dtype, requires_grad=requires_grad)  # type: ignore
+
+        tensor = torch.cat(
+            [p.detach().reshape(-1) if isinstance(p, torch.nn.Parameter) else p.reshape(-1) for p in params], 0
+        )
+        tensor.requires_grad = requires_grad
+        self.set_file_params(filename, 0)
+        self.point_to_tensor(tensor)
+
+    def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]:
+        """Return a generator of views that map to the original parameters."""
+        # Note, self.data could be sharded, so its numel is <= to the sum.
+        """
+        assert self.data.numel() <= sum(
+            self._param_numels
+        ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
+        """
+        if external_data:
+            if external_data.numel() != sum(self._param_numels):
+                raise ValueError(
+                    f"Incorrect numel of supplied data: got {external_data.numel()} but expected {sum(self._param_numels)}"
+                )
+            return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes))
+        else:
+            return (t.view(s) for (t, s) in zip(self.split(self._param_numels), self._param_shapes))
+
+    def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
+        """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
+        names = [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]
+        return names, self._param_shapes, self._param_numels
+
+    def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
+        """Used by pickle to set the internal states."""
+        (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
+        assert self.numel() <= sum(
+            self._param_numels
+        ), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
+
+    def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
+        """Support pickling between ranks."""
+        return (
+            SsdFlatParameter,  # Callable
+            # Args to the callable above
+            ([self.data], self.filename, self.requires_grad),
+            # Args to __setstate__
+            (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos),
+        )
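To make the new class concrete, here is a small standalone sketch mirroring the unit tests added further down in this diff; the shapes and the temporary file are arbitrary, and this is not part of the commit:

```python
import tempfile

import torch

import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    a = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
    b = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))

    # One flat, file-backed parameter covering both original parameters.
    flat = so.SsdFlatParameter([a, b], f.name, False)

    # get_param_views() recovers views with the original shapes.
    views = list(flat.get_param_views())
    assert views[0].shape == a.shape and torch.equal(views[1], b)

    # Write the flattened data to disk (releasing the in-memory copy),
    # then materialize it again on demand.
    flat.to_file()
    flat.to_tensor()
```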
@@ -65,6 +65,7 @@ else:
 try:
     import fairscale.experimental.nn.ssd_offload as ssd_offload
+    from fairscale.experimental.nn.ssd_offload import SsdFlatParameter

     import_ssd_offload = True
 except ImportError:
@@ -109,9 +110,9 @@ class OffloadConfig:
     """Class for specifying all arguments related to offloading parameters."""

     # Offload type: currently only supports: "ssd_offload"
-    offload_type: str = None
+    offload_type: Optional[str] = None

     # Path to the directory for storing parameters offloaded to disk.
-    ssd_filepath_dir: str = None
+    dir: Optional[str] = None


 class FullyShardedDataParallel(nn.Module):
@@ -300,7 +301,7 @@ class FullyShardedDataParallel(nn.Module):
         force_input_to_fp32: bool = False,
         verbose: bool = False,
         cpu_offload: bool = False,
-        offload_config: OffloadConfig = None,
+        offload_config: Optional[OffloadConfig] = None,
     ):
         init_start = time.time()
         super().__init__()
@@ -335,6 +336,9 @@ class FullyShardedDataParallel(nn.Module):
         if self.fp32_reduce_scatter and not self.mixed_precision:
             raise ValueError("fp32_reduce_scatter requires mixed_precision=True")

+        if self.ssd_offload and not self.flatten_parameters:
+            raise ValueError(f"offload type: '{offload_config.offload_type}' requires flatten_parameters=True")
+
         # skip validation if the process group was created above
         if process_group:
             validate_process_group(self.compute_device, self.process_group)
@@ -358,13 +362,11 @@ class FullyShardedDataParallel(nn.Module):
         # TODO(anj): Should we conditionally do this only if we have params?
         # TODO(anj): Figure out if we can allocate the buffer during sharding.
         self.buffer_size = sum(p.numel() for p in params)
+        self.ssd_directory = tempfile.gettempdir()

         if self.ssd_offload:
             assert import_ssd_offload, "We need to import ssd_offload.py to enable the `ssd_offload` feature."
-            self.ssd_buffer_filepath_dir = (
-                offload_config.ssd_filepath_dir if offload_config.ssd_filepath_dir else tempfile.gettempdir()
-            )
-            self.ssd_buffer_filename = tempfile.mkstemp(dir=self.ssd_buffer_filepath_dir)
-            self.ssd_buffer = ssd_offload.SsdBuffer(self.buffer_size, self.ssd_buffer_filename[1])
+            if offload_config and offload_config.dir:
+                self.ssd_directory = offload_config.dir
             self.move_grads_to_cpu = True
             self.move_params_to_cpu = True
@@ -379,7 +381,9 @@ class FullyShardedDataParallel(nn.Module):
             param_name_groups = [param_names]
         del param_names

-        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(module, param_list=to_be_flatten_params)
+        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
+            module, param_list=to_be_flatten_params, ssd_offload=self.ssd_offload, ssd_directory=self.ssd_directory
+        )
         del module  # free original module in case it helps garbage collection

         # Now, in this FSDP wrapper class, we keep a list of to-be-flatten and not-to-be-flatten
@@ -675,15 +679,7 @@ class FullyShardedDataParallel(nn.Module):
             p._orig_size = p.data.size()

             if not p._is_sharded:
-                if self.ssd_offload:
-                    # Insert tensor into the SSD buffer and free parameter storage.
-                    p._is_sharded = False
-                    self.numel_padded_per_param.append(0)
-                    p._shard_size = p.data.size()  # type: ignore
-                    p._handle = self.ssd_buffer.insert(p.data)  # type: ignore
-                    free_storage_(p.data)
-                    continue
-                else:
+                if not self.ssd_offload:
                     p._is_sharded = False
                     self.numel_padded_per_param.append(0)
                     continue
@@ -691,14 +687,11 @@ class FullyShardedDataParallel(nn.Module):
             # Replace p.data with the relevant shard.
             if self.ssd_offload:
-                orig_data = p.data
-                p.data, num_padded = self._get_shard(p.data)
-                p._shard_size = p.data.size()  # type: ignore
-                # Insert tensor into the SSD buffer and free parameter storage.
-                p._handle = self.ssd_buffer.insert(p.data)  # type: ignore
-                del orig_data
+                assert isinstance(p, SsdFlatParameter)
+                sharded_tensor, num_padded = self._get_shard(p.data)
+                p.point_to_resized_tensor(sharded_tensor)
                 self.numel_padded_per_param.append(num_padded)
-                free_storage_(p.data)
+                p.to_file()
             else:
                 orig_data = p.data
                 p.data, num_padded = self._get_shard(p.data)
@@ -707,10 +700,6 @@ class FullyShardedDataParallel(nn.Module):
         assert len(self.numel_padded_per_param) == len(self.params)

-        # Move SSD buffer to disk.
-        if self.ssd_offload:
-            self.ssd_buffer.to_disk()
-
     def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
         """Return the local shard of a full tensor."""
         # Shard using torch.chunk to match all-gather/reduce-scatter.
@@ -791,11 +780,6 @@ class FullyShardedDataParallel(nn.Module):
         """Returns an iterator over the module parameters, yielding all the parameters
         part of the model.
         """
-        # TODO(anj): Use `copy_into_tensor` in order to provide a copy of the
-        # parameters and not the actual parameters. Ideally we don't users to operate on
-        # actual params.
-        if self.ssd_offload:
-            self.ssd_buffer.from_disk(self.buffer_size)
         return super().parameters(recurse=recurse)
@@ -810,12 +794,6 @@ class FullyShardedDataParallel(nn.Module):
         If you want the full param to be returned, you should call this function
         under a `summon_full_params` context when using flattened or original params.
         """
-        # TODO(anj): Use `copy_into_tensor` in order to provide a copy of the
-        # parameters and not the actual parameters. Ideally we don't users to operate on
-        # actual params.
-        if self.ssd_offload:
-            self.ssd_buffer.from_disk(self.buffer_size)
         named_param = super().named_parameters(*args, **kwargs)
         for name, param in named_param:
             if (
@@ -923,10 +901,9 @@ class FullyShardedDataParallel(nn.Module):
     def _move_params_to_memory(self) -> None:
         """Move params from disk to CPU."""
-        self.ssd_buffer.from_disk(self.buffer_size)
-
-        for p, handle in zip(self.params, self.ssd_buffer.get_tensors()):
-            p.data = handle.get_tensor().view(p._shard_size)  # type: ignore
+        for p in self.params:
+            assert isinstance(p, SsdFlatParameter)
+            p.to_tensor()

     def _load_state_dict(
         self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
@@ -1072,14 +1049,13 @@ class FullyShardedDataParallel(nn.Module):
             if safe_to_free:
                 free_storage_(full_tensor)
         self.has_full_params = False
-        self._use_fp32_param_shard()

         if self.ssd_offload:
             # Store tensors in the SSD buffer and free param storage.
             for p in self.params:
-                p._shard_size = p.data.size()  # type: ignore
-                p._handle = self.ssd_buffer.insert(p.data)  # type: ignore
-                free_storage_(p.data)
-            self.ssd_buffer.to_disk()
+                assert isinstance(p, SsdFlatParameter)
+                p.to_file()
+        else:
+            self._use_fp32_param_shard()

         self.training_state = TrainingState.IDLE

     def _reset_lazy_init(self) -> None:
@@ -1153,6 +1129,11 @@ class FullyShardedDataParallel(nn.Module):
             return

         # A single shard of the parameters in full precision.
+        # TODO(another-pjohnson) - I believe this will cause memory leakage with ssd:
+        # p.data returns a pointer to a handle, and that handle has its
+        # ref count incremented by p._fp32_shard. So this tensor will
+        # never be freed even if we do p.to_disk(). Investigate after
+        # PR #887 is merged.
         p._fp32_shard = p.data

         if self.mixed_precision:
@@ -1160,12 +1141,12 @@ class FullyShardedDataParallel(nn.Module):
             if self.move_params_to_cpu:
                 assert p._fp32_shard.device == torch.device("cpu")
-                # If we plan to keep the FP32 parameters on CPU, then pinning
-                # memory allows us to later use non-blocking transfers when moving
-                # the FP32 param shard to compute_device.
+                # We don't pin memory when using ssd_offload since that results in OOM when
+                # the memory requirements of a model are larger than host memory.
                 if not self.ssd_offload:
-                    # We don't pin memory when using ssd_offload since that results in OOM when
-                    # the memory requirements of a model are larger than host memory.
+                    # If we plan to keep the FP32 parameters on CPU, then pinning
+                    # memory allows us to later use non-blocking transfers when moving
+                    # the FP32 param shard to compute_device.
                     p._fp32_shard = p._fp32_shard.pin_memory()

             p.data = p._fp32_shard
@@ -1206,9 +1187,12 @@ class FullyShardedDataParallel(nn.Module):
                 # shard in pinned memory so that we can do a non-blocking transfer.
                 # This is only needed during training and not evaluation.
                 if self.ssd_offload:
-                    # We don't pin memory when using ssd_offload since that results in OOM when
-                    # the memory requirements of a model are larger than host memory.
-                    p._cpu_grad = torch.zeros_like(p.data, device="cpu")
+                    assert isinstance(p, SsdFlatParameter)
+                    # Gradients also need to be offloaded to SSD otherwise it can result in
+                    # OOMs when the memory requirements of a model are larger than host memory.
+                    p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
+                    p._cpu_grad.set_file_params(p.filename + "_grad", 0)
+                    p._cpu_grad.to_file()
                 else:
                     p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
@@ -1298,9 +1282,6 @@ class FullyShardedDataParallel(nn.Module):
             self._streams["all_gather"].wait_stream(torch.cuda.current_stream())

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
-        if self.ssd_offload:
-            self._move_params_to_memory()
-
         self._lazy_init()

         # Start of a forward pass.
@@ -1365,7 +1346,9 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _free_ssd_offload(self) -> None:
         if self.ssd_offload:
-            self.ssd_buffer.to_disk()
+            for p in self.params:
+                assert isinstance(p, SsdFlatParameter)
+                p.to_file(permit_when_tensor_none=True)

     def _register_pre_backward_hooks(self, outputs: Any) -> Any:
         """Register pre-backward hook to run before the wrapped module's
@@ -1405,7 +1388,7 @@ class FullyShardedDataParallel(nn.Module):
         # Note, both ``self._rebuild_full_params`` and ``self._use_full_params`` are
         # idempotent. So in case they are called unnecessarily, they don't incur much
         # overhead.
-        if self.ssd_offload or self.reshard_after_forward:
+        if self.reshard_after_forward:
             self._rebuild_full_params()
         else:
             self._use_full_params()
@@ -1730,7 +1713,7 @@ class FullyShardedDataParallel(nn.Module):
         for m in self.modules():  # includes self
             if isinstance(m, FullyShardedDataParallel):
                 _finalize_parameters(m)
-                self._free_ssd_offload()
+                m._free_ssd_offload()
                 m._pre_backward_hook_has_run = False
                 if any(p.requires_grad for p in m.parameters()):
                     # Check if the module has params and if any of them has
@@ -1808,12 +1791,9 @@ class FullyShardedDataParallel(nn.Module):
                 p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

         if self.ssd_offload:
-            self.ssd_buffer.from_disk(self.buffer_size)
-
-            # The params are on disk and need to be moved to the CPU.
-            for p, handle in zip(self.params, self.ssd_buffer.get_tensors()):
-                p._fp32_shard = handle.get_tensor().view(p._shard_size)  # type: ignore
-                p.data = p._fp32_shard
+            for p in self.params:
+                assert isinstance(p, SsdFlatParameter)
+                p.to_tensor()

         self.has_full_params = False
...
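The gradient-offload change in `_init_param_attributes` above uses the plain `SsdTensorHandle` lifecycle (wrap, bind to a file, write, read back). A minimal sketch of that lifecycle, with an arbitrary file and size, not part of the commit:

```python
import tempfile

import torch

import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    grad_buffer = torch.zeros(16)

    # Wrap a CPU tensor in a handle, bind it to a backing file, and push it to
    # disk, mirroring how p._cpu_grad is handled when ssd_offload is enabled.
    handle = so.SsdTensorHandle.from_tensor(grad_buffer)
    handle.set_file_params(f.name, 0)
    handle.to_file(release_tensor_after_write=True)

    # Read it back into memory only when it is actually needed.
    restored = handle.to_tensor()
    assert torch.equal(restored, grad_buffer)
```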
@@ -8,6 +8,7 @@
 from contextlib import contextmanager
 from itertools import chain
+import tempfile
 import typing
 from typing import (
     TYPE_CHECKING,
@@ -30,6 +31,7 @@ import torch
 from torch import Tensor
 import torch.nn as nn

+from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
 from fairscale.utils.state_dict import replace_by_prefix_

 if TYPE_CHECKING:
@@ -114,6 +116,7 @@ class FlatParameter(nn.Parameter):

 # Static types.
+FlatTypes = Union[FlatParameter, SsdFlatParameter]
 ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]]
@@ -147,7 +150,14 @@ class FlattenParamsWrapper(nn.Module):
         prefix will be added to those names.
     """

-    def __init__(self, module: nn.Module, param_list: ParamGroups = None, flat_param_names: Optional[List[str]] = None):
+    def __init__(
+        self,
+        module: nn.Module,
+        param_list: ParamGroups = None,
+        flat_param_names: Optional[List[str]] = None,
+        ssd_offload: bool = False,
+        ssd_directory: str = "",
+    ):
         super().__init__()
         self._fpw_module = module
         self.is_flattened = False
@@ -195,7 +205,7 @@ class FlattenParamsWrapper(nn.Module):
             # support.
             raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}")

-        self.flat_params: List[FlatParameter] = []
+        self.flat_params: List[FlatTypes] = []

         # Prepare flat param names.
         if flat_param_names is None:
@@ -205,10 +215,16 @@ class FlattenParamsWrapper(nn.Module):
         if len(flat_param_names) != len(set(flat_param_names)):
             raise ValueError("Each flat param must be given a unique name")
         self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names]
+        flat_param: Optional[FlatTypes] = None

         # Init all flat_params.
         for new_p_set in self._param_sets:
             params, param_infos, shared_param_infos = self._init_flatten_params(new_p_set)
-            flat_param = FlatParameter(params, params[0].requires_grad)
+            if ssd_offload:
+                assert ssd_directory != ""
+                (handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
+                flat_param = SsdFlatParameter(params=params, filename=fname, requires_grad=params[0].requires_grad)
+            else:
+                flat_param = FlatParameter(params, params[0].requires_grad)
             flat_param._param_infos = param_infos
             flat_param._shared_param_infos = shared_param_infos
@@ -288,7 +304,7 @@ class FlattenParamsWrapper(nn.Module):
     def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]:
         return chain(*[p._shared_param_infos for p in self.flat_params])

-    def _flatten_params(self, flat_params: List[FlatParameter]) -> None:
+    def _flatten_params(self, flat_params: List[FlatTypes]) -> None:
         """Flatten the managed parameters and replaced the original
         attributes with views to the flat params.
         """
...
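For reference, a rough sketch of the new constructor arguments in isolation; it assumes FlattenParamsWrapper is importable from fairscale.nn.misc and can be used standalone as in its existing tests, and the toy module is an assumption:

```python
import tempfile

import torch.nn as nn

from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
from fairscale.nn.misc import FlattenParamsWrapper

module = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

with tempfile.TemporaryDirectory() as tmpdir:
    # With ssd_offload=True, each flat param is an SsdFlatParameter backed by
    # a temp file created under ssd_directory (see the mkstemp call above).
    wrapped = FlattenParamsWrapper(module, ssd_offload=True, ssd_directory=tmpdir)
    assert all(isinstance(p, SsdFlatParameter) for p in wrapped.flat_params)
```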
-from typing import Any, BinaryIO, Union
+import os
+import pickle
+from typing import Any, BinaryIO, Callable, IO, Union

+DEFAULT_PROTOCOL: int = 2

-def save(obj, f: Union[str, BinaryIO]) -> None: ...
+def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
+         pickle_module: Any=pickle, pickle_protocol: int=DEFAULT_PROTOCOL, _use_new_zipfile_serialization: bool=True) -> None: ...

 def load(f: Union[str, BinaryIO], map_location) -> Any: ...
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 """
-Testing SsdBuffer and SsdTensorHandle modules.
+Testing SsdFlatParameter and SsdTensorHandle modules.
 """

 import tempfile
@@ -38,6 +38,8 @@ def test_write_read():

 def test_ssd_handle_dispatch_fwd():
+    _init()
+
     with tempfile.NamedTemporaryFile() as f:
         orig_tensor = torch.randn((128))
         ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
@@ -54,6 +56,8 @@ def test_ssd_handle_dispatch_fwd():

 def test_ssd_handle_dispatch_bwd():
+    _init()
+
     with tempfile.NamedTemporaryFile() as f:
         orig_tensor = torch.randn((4, 4), requires_grad=True)
         orig_copy = orig_tensor.clone().detach().requires_grad_(True)
@@ -68,98 +72,89 @@ def test_ssd_handle_dispatch_bwd():
     y1.sum().backward()
     y2.sum().backward()

-    # TODO: PJ/AS enable assert once Tensor._make_subclass can properly define the tensor's shape
-    # assert torch.equal(ssd_handle.grad, orig_copy.grad)
-
-
-def test_ssd_buffer_basic():
-    _init()
-    with tempfile.NamedTemporaryFile() as f:
-        refa_tensor = torch.rand((128), dtype=torch.float32)
-        refb_tensor = torch.rand((128), dtype=torch.float32)
-        refc_tensor = torch.rand((128), dtype=torch.float32)
-        ssd_buf = so.SsdBuffer(1024, f.name)
-
-        hdl_a = ssd_buf.insert(refa_tensor)
-        hdl_b = ssd_buf.insert(refb_tensor)
-        hdl_c = ssd_buf.insert(refc_tensor)
-
-        assert hdl_a.is_available()
-        assert hdl_b.is_available()
-        assert hdl_c.is_available()
-
-        assert torch.equal(refa_tensor, hdl_a.get_tensor())
-        assert torch.equal(refb_tensor, hdl_b.get_tensor())
-        assert torch.equal(refc_tensor, hdl_c.get_tensor())
-
-        tensors = ssd_buf.get_tensors()
-        assert hdl_a is tensors[0]
-        assert hdl_b is tensors[1]
-        assert hdl_c is tensors[2]
-
-        # test read_into_tensor when handle.is_available()
-        b_tensor_copy1 = torch.empty_like(refb_tensor)
-        hdl_b.copy_into_tensor(b_tensor_copy1)
-        assert torch.equal(refb_tensor, b_tensor_copy1)
-
-        # remove references so memory will be cleaned up
-        buffer = None
-
-        ssd_buf.to_disk()
-
-        assert hdl_a.filename == f.name
-        assert hdl_b.filename == f.name
-        assert hdl_c.filename == f.name
-
-        assert hdl_a.offset == 0
-        assert hdl_b.offset == 128
-        assert hdl_c.offset == 256
-
-        assert not hdl_a.is_available()
-        assert not hdl_b.is_available()
-        assert not hdl_c.is_available()
-
-        # test read_into_tensor when !handle.is_available()
-        b_tensor_copy2 = torch.empty_like(refb_tensor)
-        hdl_b.copy_into_tensor(b_tensor_copy2)
-        assert torch.equal(refb_tensor, b_tensor_copy2)
-
-        ssd_buf.from_disk(384)
-
-        assert hdl_a.is_available()
-        assert hdl_b.is_available()
-        assert hdl_c.is_available()
-
-        assert torch.equal(refa_tensor, hdl_a.get_tensor())
-        assert torch.equal(refb_tensor, hdl_b.get_tensor())
-        assert torch.equal(refc_tensor, hdl_c.get_tensor())
-
-
-def test_ssd_buffer_too_small_from_disk():
-    _init()
-    with tempfile.NamedTemporaryFile() as f:
-        refa_tensor = torch.rand((128), dtype=torch.float32)
-        ssd_buf = so.SsdBuffer(128, f.name)
-        hdl_a = ssd_buf.insert(refa_tensor)
-        ssd_buf.to_disk()
-
-        with pytest.raises(RuntimeError):
-            ssd_buf.from_disk(127)
-
-
-def test_ssd_buffer_null_buffer():
-    _init()
-    with tempfile.NamedTemporaryFile() as f:
-        refa_tensor = torch.rand((128), dtype=torch.float32)
-        ssd_buf = so.SsdBuffer(128, f.name)
-        hdl_a = ssd_buf.insert(refa_tensor)
-        ssd_buf.to_disk()
-
-        with pytest.raises(AssertionError):
-            hdl_a = ssd_buf.insert(refa_tensor)
-
-        with pytest.raises(AssertionError):
-            ssd_buf.can_alloc(128)
-
-        with pytest.raises(AssertionError):
-            hdl = ssd_buf.allocate(128)
+    assert torch.equal(ssd_handle.grad, orig_copy.grad)
+
+
+def test_ssd_handle_train_simple():
+    _init()
+
+    with tempfile.NamedTemporaryFile() as f:
+        orig_tensor = torch.randn((4, 4), requires_grad=True)
+
+        with torch.no_grad():
+            orig_copy = torch.empty_like(orig_tensor)
+            orig_copy.copy_(orig_tensor)
+            orig_copy.requires_grad = True
+
+        ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
+        ssd_handle.set_file_params(f.name, 0)
+        ssd_handle.to_file(release_tensor_after_write=True)
+
+        assert torch.equal(ssd_handle.to_tensor(), orig_tensor)
+        optimizer_ssd = torch.optim.SGD([ssd_handle], lr=0.1)
+        optimizer_orig = torch.optim.SGD([orig_copy], lr=0.1)
+
+        y1 = ssd_handle + 1
+        optimizer_ssd.zero_grad()
+        y1.sum().backward()
+        optimizer_ssd.step()
+
+        y2 = orig_copy + 1
+        optimizer_orig.zero_grad()
+        y2.sum().backward()
+        optimizer_orig.step()
+
+        # make sure we are using the file version not the cached tensor
+        ssd_handle.point_to_file(f.name, 0)
+        assert torch.equal(ssd_handle.to_tensor(), orig_copy)
+
+
+def test_ssd_flat_param_train_simple():
+    _init()
+    with tempfile.NamedTemporaryFile() as f:
+        orig_tensor = torch.randn((4, 4))
+
+        with torch.no_grad():
+            orig_copy = torch.empty_like(orig_tensor)
+            orig_copy.copy_(orig_tensor)
+            param = torch.nn.Parameter(orig_copy)
+
+        ssd_flat_param = so.SsdFlatParameter([param], f.name, True)
+
+        assert torch.equal(list(ssd_flat_param.get_param_views())[0], orig_tensor)
+        optimizer_ssd = torch.optim.SGD([ssd_flat_param], lr=0.1)
+        optimizer_orig = torch.optim.SGD([param], lr=0.1)
+
+        y1 = ssd_flat_param + 1
+        optimizer_ssd.zero_grad()
+        y1.sum().backward()
+        optimizer_ssd.step()
+
+        y2 = param + 1
+        optimizer_orig.zero_grad()
+        y2.sum().backward()
+        optimizer_orig.step()
+
+        # make sure we are using the file version not the cached tensor
+        ssd_flat_param.point_to_file(f.name, 0)
+        assert torch.equal(list(ssd_flat_param.get_param_views())[0], param)
+
+
+def test_ssd_flat_parameter_basic():
+    _init()
+    with tempfile.NamedTemporaryFile() as f:
+        refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
+        refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
+        refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
+
+        ssd_flat_param = so.SsdFlatParameter([refa_param, refb_param, refc_param], f.name, False)
+
+        param_views = list(ssd_flat_param.get_param_views())
+
+        assert refa_param.shape == param_views[0].shape
+        assert refb_param.shape == param_views[1].shape
+        assert refc_param.shape == param_views[2].shape
+
+        assert torch.equal(refa_param, param_views[0])
+        assert torch.equal(refb_param, param_views[1])
+        assert torch.equal(refc_param, param_views[2])
+
+        ssd_flat_param.to_file()
@@ -102,6 +102,8 @@ class DistributedTest(unittest.TestCase):
         # Confirm we get the same behavior using FullyShardedDataParallel.
         if config.get("ssd_offload", False):
             config["offload_config"] = OffloadConfig(offload_type="ssd_offload")
+            # ssd offload only supports flatten_params ATM
+            config["flatten_parameters"] = True
             del config["ssd_offload"]

         model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
@@ -121,7 +123,7 @@ class DistributedTest(unittest.TestCase):
     assert isinstance(metadata, dict)

-keys = ["reshard_after_forward", "mixed_precision", "flatten_parameters", "nested_wrapping"]
+keys = ["reshard_after_forward", "mixed_precision", "nested_wrapping"]
 CONFIG_OPTIONS = [[dict(zip(keys, config))] for config in itertools.product([True, False], repeat=len(keys))]
@@ -152,7 +154,7 @@ class TestSsdMemory(DistributedTest):
         time_keeper.print_time("CPU_MODEL", 1.0)

         with tempfile.TemporaryDirectory() as current_tempdir:
-            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)

             model = FullyShardedDataParallel(model, **config)
             time_keeper.print_time("FSDP_MODEL", 1.0)
@@ -223,7 +225,9 @@ class TestModuleProperties(DistributedTest):
         with tempfile.TemporaryDirectory() as current_tempdir:
             if config["ssd_offload"]:
-                config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+                config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
+                # ssd offload only supports flatten_params ATM
+                config["flatten_parameters"] = True
             del config["ssd_offload"]

             model = FullyShardedDataParallel(before_wrap_model, **config)
@@ -260,6 +264,47 @@ class TestSsdLoading(DistributedTest):
     def test_transformer_parameterized(self, config):
         spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))

+    @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
+    def test_ssd_offloading_train_flatten_params_wrapper(self, config):
+        test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
+        spawn_and_init(test_fn)
+
+    @classmethod
+    def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config):
+        SIZE = 16 * 16
+        model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
+        with tempfile.TemporaryDirectory() as current_tempdir:
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
+            config["flatten_parameters"] = True
+
+            nested_wrapping = config["nested_wrapping"]
+            del config["nested_wrapping"]
+
+            if nested_wrapping:
+                model = FullyShardedDataParallel(
+                    NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
+                )
+            else:
+                model = FullyShardedDataParallel(model, **config)
+
+            model_device = torch.device("cuda")
+            model.train()
+            optim = torch.optim.SGD(model.parameters(), lr=4, momentum=0.9)
+            # Inputs always cuda regardless of move_grads_cpu, or model.device
+            with torch.cuda.amp.autocast(enabled=config.get("mixed_precision", False)):
+                for i in range(10):
+                    optim.zero_grad()
+                    input = model.get_input(torch.device("cuda"))
+                    output = model(*input)
+                    loss = model.module.get_loss(input, output).to(model_device)
+                    assert loss.dtype == torch.float32
+                    model.module.run_backward(loss)
+                    optim.step()
+                    if isinstance(model, FullyShardedDataParallel):
+                        model.assert_state(TrainingState.IDLE)
+
     @classmethod
     def _test_ssd_offload_eval(self, rank, group, config):
         model = TransformerWithSharedParams(group)
@@ -267,9 +312,10 @@ class TestSsdLoading(DistributedTest):
         nested_wrapping = config["nested_wrapping"]
         del config["nested_wrapping"]
+        config["flatten_parameters"] = True

         with tempfile.TemporaryDirectory() as current_tempdir:
-            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", ssd_filepath_dir=current_tempdir)
+            config["offload_config"] = OffloadConfig(offload_type="ssd_offload", dir=current_tempdir)
             if nested_wrapping:
                 model = FullyShardedDataParallel(
                     NestedWrappedModule(group, wrap_everything=True, wrapper_config=config)
...