Unverified Commit 51b53ddb authored by Paul Johnson, committed by GitHub


[FSDP] ssd_offload fixing backward path (grad_fn) for SsdFlatParameter and SsdFlatParameterView (#974)

* [FSDP] fixing backward path for SsdFlatParameter and SsdFlatParameterView when overriding .data

* Get ssd_offload unit tests passing

* [FSDP] get all test_fsdp_offload tests passing w/ ssd_offload on

* Update changelog
parent 58ccb166
......@@ -12,7 +12,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes),
verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict()
and model.load_state_dict(...) and thus successfully checkpoint and recover parameters
stored as SsdFlatParameters.
stored as SsdFlatParameters. [#964]
- FSDP Ssd Offload: [#974]
* Add a __setattr__ function to SsdTensorHandle/SsdFlatParameter to capture when .data is overridden
and perform the necessary checks/updates before letting the tensor metadata be updated.
* There was a bug in PyTorch core where re-assigning a tensor subclass's .data disabled the
__torch_dispatch__ mechanism and effectively reverted the subclass to a normal tensor. Since the
ssd_offload feature uses __torch_dispatch__ extensively and FSDP overrides .data, ssd_offload can
now only be used with PyTorch 1.12.0 or later (currently the pytorch-nightly release). On older
versions the pytest tests are skipped, and trying to import ssd_offload or use FSDP with
ssd_offload enabled will raise an ImportError.
* Split the storage_state ON_CPU into ON_CPU_CLEAN and ON_CPU_DIRTY. ON_CPU_CLEAN indicates
that the .tensor value and the value stored on disk are identical, so no write is needed when
calling .to_file(). ON_CPU_DIRTY indicates that .tensor does not match the value stored on disk
and a write is necessary. This optimization is disabled in FSDP because FSDP does not currently
flush variables to disk as soon as the optimizer modifies values (see the state-transition
sketch after this list).
* Fix detection in SsdTensorHandle.__torch_dispatch__ of when the handle is used by an in-place
operator or as an out= argument, and call mark_dirty() in those cases.
* Fix grad_fn on SsdFlatParameterView so that each view's grad_fn points to a split and view of
the parent ssd_flat_param; that way, when X.backward() is called, ssd_flat_param.grad is properly
updated in the backward pass.
* Add unit tests to verify grad_fn and backward hooks are called appropriately on
SsdFlatParameters/SsdFlatParameterViews
* Implement the SsdFlatParameter.from_tensors() direct_to_file option. This allows creating an
SsdFlatParameter and writing it directly to disk, rather than materializing a new flattened
tensor first. This avoids instantiating another copy of the tensors and can save memory when
creating very large SsdFlatParameters.
* Add SsdFlatParameterViewProperty for overriding a parameter in an nn.Module, similar to
pytorch core's parametrization code path. This keeps SsdFlatParameterViews up to date as
SsdFlatParameter.data is overridden, by replacing Layer.weight with a property whose getter
returns ssd_flat_param.views[view_id].
* Add an example SsdFlatParameterViewParameterization showing how the existing parametrization
mechanism could be used instead of SsdFlatParameterViewProperty. This method, however, is
memory-inefficient because parametrization always keeps a copy of the original
parameter internally.
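As a rough state-transition sketch, mirroring the SsdTensorHandle usage in the unit tests below (the temp file and shape are arbitrary, and flush_on_dirty is turned off so the writes stay explicit):

import tempfile
import torch
import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    handle = so.SsdTensorHandle.from_tensor(torch.randn(4, 4))  # -> ON_CPU_DIRTY
    handle.flush_on_dirty = False                     # keep writes explicit for this sketch
    handle.set_file_params(f.name, 0)
    handle.to_file(release_tensor_after_write=True)   # dirty tensor written out   -> ON_DISK
    handle.to_tensor()                                # read back from file        -> ON_CPU_CLEAN
    handle.add_(1.0)                                  # in-place op marks it dirty -> ON_CPU_DIRTY
    handle.to_file(release_tensor_after_write=False)  # write-back, keep tensor    -> ON_CPU_CLEAN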
### Fixed
......
......@@ -11,18 +11,24 @@ import io
import os
import pickle
from types import TracebackType
from typing import IO, Any, BinaryIO, Iterator, List, Optional, Sequence, Tuple, Type, Union
from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union
import numpy as np
import torch
from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL
from fairscale.utils import torch_version
try:
from torch.utils._pytree import tree_map
except ImportError:
# The PyTorch version(<1.9) we test with does not support the tree_map API.
pass
if torch_version() < (1, 12, 0):
raise ImportError(
f"ssd_offload only works on torch versions 1.12.0 and beyond, but torch version is: {torch.__version__}"
)
DEFAULT_CHUNK_SIZE = 2048 * 2048
......@@ -69,7 +75,10 @@ def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0)
chunk_start = i * chunk_size_bytes
chunk_end = min(size_in_bytes, chunk_start + chunk_size_bytes)
data_read = f.readinto(input_tensor_mv[chunk_start:chunk_end])
assert data_read == chunk_end - chunk_start
if data_read != chunk_end - chunk_start:
raise RuntimeError(
f"Attempted to read {chunk_end - chunk_start} more bytes from {filename}, but only read: {data_read} bytes. Total Bytes read = {chunk_start + data_read}, total bytes expected: {size_in_bytes}"
)
class StorageState(Enum):
......@@ -82,7 +91,8 @@ class StorageState(Enum):
UNALLOCATED = auto()
ON_DISK = auto()
ON_CPU = auto()
ON_CPU_CLEAN = auto()
ON_CPU_DIRTY = auto()
class SsdTensorHandle(torch.Tensor):
......@@ -109,12 +119,26 @@ class SsdTensorHandle(torch.Tensor):
@staticmethod
def __new__(
cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool = False
cls: Type[SsdTensorHandle],
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = False,
device: torch.device = torch.device("cpu"),
flush_on_dirty: bool = True,
allow_unsafe_changes: bool = False,
) -> SsdTensorHandle:
r = super(SsdTensorHandle, cls)._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad) # type: ignore
r = super(SsdTensorHandle, cls)._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad, device=device) # type: ignore
return r
def __init__(self, shape: torch.Size, dtype: torch.dtype, requires_grad: bool) -> None:
def __init__(
self,
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool,
device: torch.device = torch.device("cpu"),
flush_on_dirty: bool = True,
allow_unsafe_changes: bool = False,
) -> None:
self._unpickle_f: Optional[Union[BinaryIO, IO[bytes]]] = None
self._shape = shape
......@@ -128,8 +152,17 @@ class SsdTensorHandle(torch.Tensor):
self.offset = -1
# valid if loaded to memory
self.tensor: Optional[torch.Tensor] = None
self.requires_grad = requires_grad
self.storage_state = StorageState.UNALLOCATED
self.flush_on_dirty = flush_on_dirty
self.allow_unsafe_changes = allow_unsafe_changes
def mark_dirty(self) -> None:
assert self.tensor is not None
assert self.storage_state in [StorageState.ON_CPU_CLEAN, StorageState.ON_CPU_DIRTY]
self.storage_state = StorageState.ON_CPU_DIRTY
# hack to force write on mark_dirty
if self.flush_on_dirty:
self.to_file()
@classmethod
def from_file(
......@@ -143,7 +176,7 @@ class SsdTensorHandle(torch.Tensor):
@classmethod
def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle:
"""Returns a new SsdTensorHandle from a tensor."""
handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad)
handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device)
handle.point_to_tensor(tensor)
return handle
......@@ -165,10 +198,11 @@ class SsdTensorHandle(torch.Tensor):
def point_to_tensor(self, tensor: torch.Tensor) -> None:
assert self.tensor is None
assert self._shape == tensor.shape
if not self.allow_unsafe_changes:
assert self._shape == tensor.shape
assert self._dtype == tensor.dtype
self.tensor = tensor
self.storage_state = StorageState.ON_CPU
self.storage_state = StorageState.ON_CPU_DIRTY
# if resizing a handle that is part of an ssd buffer, care must be taken that the new size
# doesn't conflict with adjacent handles!
......@@ -185,21 +219,33 @@ class SsdTensorHandle(torch.Tensor):
if self.tensor is not None:
return self.tensor
else:
result_tensor = torch.empty(size=self._shape, dtype=self._dtype, requires_grad=self.requires_grad)
if self.device != torch.device("cpu"):
raise RuntimeError(
f"to_tensor called on an SsdTensorHandle when the tensor has been offloaded to disk. self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!"
)
result_tensor = torch.empty(size=self.shape, dtype=self.dtype, requires_grad=self.requires_grad)
self.copy_into_tensor(result_tensor)
self.tensor = result_tensor
self.storage_state = StorageState.ON_CPU
self.storage_state = StorageState.ON_CPU_CLEAN
return self.tensor
def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
"""Saves the tensor to disk and releases memory if specified."""
assert self.tensor is not None or permit_when_tensor_none
# if it's available in Memory but not modified, no need to write-back
if self.tensor is not None:
write(self.tensor, self.filename, self.offset * self.tensor.element_size())
if self.storage_state is StorageState.ON_CPU_DIRTY:
if self.device != torch.device("cpu"):
raise RuntimeError(
f"to_file called on an SsdTensorHandle when self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!"
)
write(self.tensor, self.filename, self.offset * self.tensor.element_size())
if release_tensor_after_write:
self.tensor = None
self.storage_state = StorageState.ON_DISK
else:
self.storage_state = StorageState.ON_CPU_CLEAN
def copy_into_tensor(self, tensor: torch.Tensor) -> None:
"""Copies SsdTensorHandle's data into the given tensor.
......@@ -211,7 +257,9 @@ class SsdTensorHandle(torch.Tensor):
function. This can be useful for calls like named_parameters() when
the tensor is already offloaded to disk.
"""
assert self._shape == tensor.shape
# ideally this should be checked, but .data shenanigans force it to
# be disabled due to the way FSDP shards parameters
# assert self._shape == tensor.shape
assert self._dtype == tensor.dtype
if self.tensor is not None:
tensor.copy_(self.tensor)
......@@ -229,25 +277,48 @@ class SsdTensorHandle(torch.Tensor):
versions to track if modifications have been made. If we detect changes to the
tensor, we write it to the file maintained by the Handle.
"""
func_name = func.overloadpacket.__name__
ssd_tensor_handles = []
def unwrap(e: Any) -> torch.Tensor:
if isinstance(e, SsdTensorHandle):
t = e.to_tensor()
ssd_tensor_handles.append((e, t._version)) # type: ignore
ssd_tensor_handles.append(e)
return t
else:
return e
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for e, saved_version in ssd_tensor_handles:
inplace_is_this_tensor = func.__name__[-1] == "_" and e is args[0]
for e in ssd_tensor_handles:
inplace_is_this_tensor = (
(func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i")
) and e is args[0]
out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"]
if inplace_is_this_tensor or out_is_this_tensor:
e.to_file()
e.mark_dirty()
return r
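The in-place check above keys off ATen overload-packet names. Restated as a standalone rule (the helper name here is hypothetical, for illustration only):

def _looks_inplace(func_name: str) -> bool:
    # Same condition as in __torch_dispatch__ above: a trailing single underscore
    # (e.g. "add_") or a dunder starting with "__i" (e.g. "__iadd__") marks an in-place op.
    return (func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i")

assert _looks_inplace("add_")
assert _looks_inplace("__iadd__")
assert not _looks_inplace("__add__")
assert not _looks_inplace("copy")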
def __setattr__(self, name: str, value: Any) -> None:
if name == "data":
assert isinstance(value, torch.Tensor)
if not self.allow_unsafe_changes:
# Respect .data changes, and the user better know what they are doing!
if self.storage_state == StorageState.ON_CPU_DIRTY:
raise RuntimeError(
"Attempting to override tensor when the existing tensor is dirty, this is an error!"
)
if value.shape != self.shape:
raise RuntimeError(
f"Attempting to override tensor metadata using .data to change shape of tensor. Orig shape: {self.shape} New shape: {value.shape}"
)
if value.requires_grad != self.requires_grad:
raise RuntimeError(
f"Attempting to override tensor metadata using .data to change requires_grad. Orig value: {self.requires_grad} New value: {value.requires_grad}"
)
self.tensor = value
super(SsdTensorHandle, self).__setattr__(name, value)
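A minimal usage sketch of the guard above, assuming only the SsdTensorHandle API shown in this diff (the temp file and shapes are arbitrary):

import tempfile
import torch
import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    handle = so.SsdTensorHandle.from_tensor(torch.randn(4, 4))
    handle.set_file_params(f.name, 0)
    handle.to_file(release_tensor_after_write=True)  # flush first: a dirty handle refuses .data overrides
    handle.data = torch.zeros(4, 4)                  # accepted: same shape and requires_grad
    try:
        handle.data = torch.zeros(2, 2)              # rejected: shape change without allow_unsafe_changes
    except RuntimeError:
        pass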
@classmethod
def __unpickle__(
cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool, filename: str
......@@ -264,7 +335,7 @@ class SsdTensorHandle(torch.Tensor):
head, tail = os.path.split(self.filename)
filename = os.path.join(self.override_directory_path, tail)
if self.is_available():
byte_iter = iter(TensorChunkingIterator(self.tensor))
byte_iter = iter(TensorChunkingIterator(self.tensor))  # type: ignore
else:
byte_iter = iter(
FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size())
......@@ -358,19 +429,29 @@ class TorchSaver:
class SsdParameter(SsdTensorHandle, torch.nn.Parameter):
@classmethod
def from_tensor(cls: Type[SsdParameter], tensor: SsdTensorHandle) -> SsdParameter: # type: ignore
r = cls(tensor.shape, tensor.dtype, tensor.requires_grad)
r = cls(tensor.shape, tensor.dtype, tensor.requires_grad, device=tensor.device)
r.point_to_tensor(tensor)
return r
@staticmethod
def __new__(
cls: Type[SsdParameter], shape: torch.Size, dtype: torch.dtype, requires_grad: bool = True
cls: Type[SsdParameter],
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> SsdParameter:
r = super(SsdParameter, cls).__new__(cls, shape, dtype=dtype, requires_grad=requires_grad)
r = super(SsdParameter, cls).__new__(cls, shape=shape, dtype=dtype, requires_grad=requires_grad, device=device)
return r # type: ignore
def __init__(self, shape: torch.Size, dtype: torch.dtype, requires_grad: bool = True) -> None:
super(SsdParameter, self).__init__(shape, dtype, requires_grad)
def __init__(
self,
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> None:
super(SsdParameter, self).__init__(shape=shape, dtype=dtype, requires_grad=requires_grad, device=device)
class SsdFlatParameter(SsdParameter):
......@@ -381,7 +462,11 @@ class SsdFlatParameter(SsdParameter):
"""
def __new__(
cls: Type[SsdFlatParameter], shapes: Sequence[torch.Size], dtype: torch.dtype, requires_grad: bool = True
cls: Type[SsdFlatParameter],
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> SsdFlatParameter:
"""Make an object using the parent's __new__ function."""
......@@ -390,10 +475,18 @@ class SsdFlatParameter(SsdParameter):
raise ValueError("An non-empty list or tuple argument is needed")
size = sum([np.prod(s) for s in shapes])
r = super(SsdFlatParameter, cls).__new__(cls, torch.Size((size,)), dtype=dtype, requires_grad=requires_grad)
r = super(SsdFlatParameter, cls).__new__(
cls, torch.Size((size,)), dtype=dtype, requires_grad=requires_grad, device=device
)
return r # type: ignore
def __init__(self, shapes: Sequence[torch.Size], dtype: torch.dtype, requires_grad: bool = True):
def __init__(
self,
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
):
"""Initialize the _param_numels and _param_shapes lists."""
self._param_shapes = shapes
self._param_numels = [np.prod(s) for s in shapes]
......@@ -402,6 +495,7 @@ class SsdFlatParameter(SsdParameter):
self.numel() <= total_numels
), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
self.views: List[SsdFlatParameterView] = []
# These are set by FPW class below, not by this class itself.
self._param_infos: List[Tuple[str, torch.nn.Module, str]] = []
self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = []
......@@ -410,6 +504,31 @@ class SsdFlatParameter(SsdParameter):
shape=torch.Size((total_numels,)), dtype=dtype, requires_grad=requires_grad
)
def __setattr__(self, name: str, value: Any) -> None:
super(SsdFlatParameter, self).__setattr__(name, value)
if name == "data":
# if .data has changed, we need to totally destroy any existing views because things
# like device might have changed. It won't destroy any pointers to those views outside
# of here; however, resetting self.views will trigger the old views' assertion in
# __torch_dispatch__ that each one is the current view of its parent object
self.views = []
self._refresh_views()
def _invalidate_views(self) -> None:
for v in self.views:
v.tensor = None
@torch.enable_grad()
def _refresh_views(self) -> None:
if self._shape != self.shape:
self.views = []
return
if len(self.views) == 0:
self.views = [s.view(v) for s, v in zip(self.split(self._param_numels), self._param_shapes)] # type: ignore
else:
for v, t, s in zip(self.views, self.tensor.split(self._param_numels), self._param_shapes):
v.tensor = t.view(s)
def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]:
"""Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum.
......@@ -425,7 +544,15 @@ class SsdFlatParameter(SsdParameter):
)
return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes))
else:
return (t.view(s) for (t, s) in zip(self.split(self._param_numels), self._param_shapes))
# this needs to return SsdFlatParameterViews
if not self.is_available():
self.to_tensor()
if len(self.views) == 0:
raise RuntimeError(
"Trying to call get_param_views when self.views is empty, this means that .data games have been played and the current .data shape doesn't match the constructed shape."
)
return (v for v in self.views)
def metadata(self) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter."""
......@@ -434,7 +561,11 @@ class SsdFlatParameter(SsdParameter):
@classmethod
def from_tensors(
cls: Type[SsdFlatParameter], tensors: Sequence[torch.Tensor], direct_to_file: bool = False
cls: Type[SsdFlatParameter],
tensors: Sequence[torch.Tensor],
direct_to_file: bool = False,
filename: str = "",
offset: int = 0,
) -> "SsdFlatParameter":
"""Returns a new SsdFlatParameter from a sequence of tensors."""
assert (
......@@ -449,17 +580,75 @@ class SsdFlatParameter(SsdParameter):
if any(isinstance(t, SsdFlatParameter) for t in tensors):
raise ValueError("Nesting SsdFlatParameter is not supported")
handle = cls(shapes=[t.size() for t in tensors], dtype=tensors[0].dtype, requires_grad=tensors[0].requires_grad)
requires_grad = tensors[0].requires_grad
dtype = tensors[0].dtype
device = tensors[0].device
for t in tensors:
if t.requires_grad != requires_grad:
raise RuntimeError("Not all tensors have identical requires_grad option")
if t.dtype != dtype:
raise RuntimeError("Not all tensors have identical dtype option")
if t.device != device:
raise RuntimeError("Not all tensors have identical device option")
handle = cls(
shapes=[t.size() for t in tensors],
dtype=tensors[0].dtype,
requires_grad=tensors[0].requires_grad,
device=device,
)
handle.set_file_params(filename, offset)
if direct_to_file:
assert False, "direct_to_file not implemented yet"
pass
assert filename != ""
offset = offset
for t in tensors:
write(t, handle.filename, offset)
offset += t.numel() * t.element_size()
handle.storage_state = StorageState.ON_DISK
else:
tensor = torch.cat(
[t.detach().reshape(-1) if isinstance(t, torch.nn.Parameter) else t.reshape(-1) for t in tensors], 0
)
[t.reshape(-1) if isinstance(t, torch.nn.Parameter) else t.reshape(-1) for t in tensors],
0,
).detach()
tensor.requires_grad_()
handle.point_to_tensor(tensor)
return handle
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore
func_name = func.overloadpacket.__name__
r = super(SsdFlatParameter, cls).__torch_dispatch__(func, types, args, kwargs) # type: ignore
if func_name.startswith("split"):
assert isinstance(args[0], SsdFlatParameter)
parent = args[0]
return [SsdFlatParameterView(parent, t, idx) for idx, t in enumerate(r)]
else:
return r
# need to subclass these methods to support Views
def point_to_tensor(self, tensor: torch.Tensor) -> None:
super(SsdFlatParameter, self).point_to_tensor(tensor)
self._refresh_views()
def point_to_file(self, filename: str, offset: int) -> None:
super(SsdFlatParameter, self).point_to_file(filename, offset)
self._invalidate_views()
def to_tensor(self) -> torch.Tensor:
call_refresh_views = False
if self.tensor is None:
call_refresh_views = True
result = super(SsdFlatParameter, self).to_tensor()
if call_refresh_views:
self._refresh_views()
return result
def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
super(SsdFlatParameter, self).to_file(permit_when_tensor_none, release_tensor_after_write)
self._invalidate_views()
@classmethod
def __unpickle_SFP__(
cls: Type[SsdFlatParameter],
......@@ -494,6 +683,194 @@ class SsdFlatParameter(SsdParameter):
)
class SsdFlatParameterView(torch.Tensor):
"""
Represents a view into an SsdFlatParameter. It is needed due to FSDP's usage of flattening parameters.
"""
def __new__(
cls: Type[SsdFlatParameterView], parent: SsdFlatParameter, tensor: torch.Tensor, id: int
) -> SsdFlatParameterView:
r = super(SsdFlatParameterView, cls)._make_wrapper_subclass(cls, tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device) # type: ignore
return r
def __init__(self: SsdFlatParameterView, parent: SsdFlatParameter, tensor: torch.Tensor, id: int) -> None:
self.parent = parent
self.tensor: Optional[torch.Tensor] = tensor
self.id = id
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore
"""Intercepts all operations performed on this handle object.
Before any operation, the tensor attribute is unwrapped from the handle
and used in the operation. We maintain a reference to the tensor and its current
version to track if modifications have been made. If we detect changes to the
tensor, we write it to the file maintained by the Handle.
"""
func_name = func.overloadpacket.__name__
ssd_tensor_handles = []
def unwrap(e: Any) -> torch.Tensor:
if isinstance(e, SsdFlatParameterView):
if not e.parent.is_available():
e.parent.to_tensor()
# The first condition handles the case where e.parent.views is still being
# constructed by the list comprehension and hasn't been assigned yet
if len(e.parent.views) != 0 and e is not e.parent.views[e.id]:
raise RuntimeError(
"This view should no longer be used as the parent object has had it's .data overwritten (e.parent.views[e.id])!!!"
)
# e.parent will ensure that e.tensor is valid and points to tensor view
t = e.tensor
ssd_tensor_handles.append(e)
return t
else:
return e
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for e in ssd_tensor_handles:
inplace_is_this_tensor = (
(func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i")
) and e is args[0]
out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"]
if inplace_is_this_tensor or out_is_this_tensor:
e.parent.mark_dirty()
if func_name.startswith("view"):
assert isinstance(args[0], SsdFlatParameterView)
flat_view = args[0]
return SsdFlatParameterView(flat_view.parent, r, flat_view.id)
return r
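For orientation, a small sketch of how the views behave, using only the SsdFlatParameter/SsdFlatParameterView API introduced here (the temp file and shapes are arbitrary):

import tempfile
import torch
import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    a, b = torch.rand(32, 4), torch.rand(128)
    flat = so.SsdFlatParameter.from_tensors([a, b], direct_to_file=False)
    flat.set_file_params(f.name, 0)
    views = list(flat.get_param_views())
    # Splitting routes through __torch_dispatch__, so each view is an SsdFlatParameterView
    # that stays wired to its parent flat parameter (and thus to the parent's grad_fn).
    assert all(isinstance(v, so.SsdFlatParameterView) for v in views)
    assert views[0].shape == a.shape and views[1].shape == b.shape
    assert views[0] is flat.views[0]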
# ###################################
# ### BEGIN OVERRIDE_PROPERTY FNs ###
# ###################################
# This code is taken mostly from pytorch core parameterization
# pytorch/torch/nn/utils/parametrize.py
def _inject_new_class(module: torch.nn.Module) -> None:
r"""Sets up a module to be parametrized.
This works by substituting the class of the module with a class
that extends it, so that a property can be injected
Args:
module (nn.Module): module into which to inject the property
"""
cls = module.__class__
def getstate(self): # type: ignore
raise RuntimeError(
"Serialization of parametrized modules is only "
"supported through state_dict(). See:\n"
"https://pytorch.org/tutorials/beginner/saving_loading_models.html"
"#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
)
param_cls = type(
f"Parametrized{cls.__name__}",
(cls,),
{
"__getstate__": getstate,
},
)
module.__class__ = param_cls
module.override_properties: Dict[str, Callable[[], torch.Tensor]] = {} # type: ignore
# setattr(module, "override_properties", {})
def _inject_property(module: torch.nn.Module, property_name: str) -> None:
r"""Injects a property into module[property_name].
It assumes that the class in the module has already been modified from its
original one using _inject_new_class and that the tensor under :attr:`property_name`
has already been moved out
Args:
module (nn.Module): module into which to inject the property
property_name (str): name of the property to create
"""
def get_parametrized(self: torch.nn.Module) -> torch.Tensor:
prop: Callable[[], torch.Tensor] = self.override_properties[property_name] # type: ignore
# If caching is not active, this function just evaluates the parameterization
return prop()
def set_original(self: torch.nn.Module, value: Callable[[], torch.Tensor]) -> None:
self.override_properties[property_name] = value # type: ignore
def del_fn(self: torch.nn.Module) -> None:
_remove_property(self, property_name)
setattr(module.__class__, property_name, property(get_parametrized, set_original, del_fn))
def _register_property(module: torch.nn.Module, property_name: str, property_value: Callable[[], torch.Tensor]) -> None:
has_injected_class = hasattr(module, "override_properties")
if not has_injected_class:
_inject_new_class(module)
if hasattr(module, property_name):
delattr(module, property_name)
module.override_properties[property_name] = property_value # type: ignore
_inject_property(module, property_name)
def _remove_property(module: torch.nn.Module, property_name: str, new_property_value: Optional[Any] = None) -> None:
delattr(module.__class__, property_name)
del module.override_properties[property_name] # type: ignore
# Roll back the parametrized class if no other buffer or parameter
# is currently parametrized in this class
if len(module.override_properties) == 0: # type: ignore
delattr(module, "override_properties")
# Restore class
orig_cls = module.__class__.__bases__[0]
module.__class__ = orig_cls
if new_property_value is not None:
setattr(module.__class__, property_name, new_property_value)
# #################################
# ### END OVERRIDE_PROPERTY FNs ###
# #################################
class SsdFlatParameterViewProperty:
"""
Allows for a mutable view to replace a layer's trainable parameters.
This is needed because FSDP changes .data under the covers; SsdFlatParameter
cannot simply rely on that, since each view (of type SsdFlatParameterView) keeps
its own internal representation. So every time we access a view, we need to
make sure we get the up-to-date version, not the original version captured
when the parameters were flattened.
"""
def __init__(self, parent: SsdFlatParameter, view_id: int) -> None:
super().__init__()
self.parent = parent
self.view_id = view_id
def __call__(self) -> SsdFlatParameterView:
return self.parent.views[self.view_id]
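A usage sketch of the property override, following the _register_property(...) call made in FlattenParamsWrapper further down in this diff (the Linear layer here is a hypothetical stand-in):

import tempfile
import torch
import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    layer = torch.nn.Linear(4, 8, bias=False)
    flat = so.SsdFlatParameter.from_tensors([layer.weight], direct_to_file=False)
    flat.set_file_params(f.name, 0)
    # Replace layer.weight with a property whose getter always returns the current
    # view, so the module keeps following flat even after flat.data is overridden.
    so._register_property(layer, "weight", so.SsdFlatParameterViewProperty(flat, 0))
    assert layer.weight is flat.views[0]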
class SsdFlatParameterViewParameterization(torch.nn.Module):
def __init__(self, parent: SsdFlatParameter, view_id: int) -> None:
super().__init__()
self.parent = parent
self.view_id = view_id
def forward(self, *args: Any, **kwargs: Any) -> SsdFlatParameterView:
return self.parent.views[self.view_id]
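And the parametrize-based alternative, as exercised by the unit tests in this patch (a sketch; as noted in the changelog, parametrize keeps a copy of the original parameter around):

import tempfile
import torch
from torch.nn.utils import parametrize
import fairscale.experimental.nn.ssd_offload as so

with tempfile.NamedTemporaryFile() as f:
    layer = torch.nn.Linear(4, 8, bias=False)
    flat = so.SsdFlatParameter.from_tensors([layer.weight], direct_to_file=False)
    flat.set_file_params(f.name, 0)
    parametrize.register_parametrization(
        layer, "weight", so.SsdFlatParameterViewParameterization(flat, 0)
    )
    # Accessing layer.weight now runs the parametrization and yields the live view.
    assert torch.equal(layer.weight, flat.views[0])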
class DisableMemoizationPicklerModule:
@classmethod
def Pickler(cls, data_buf: io.BytesIO, protocol: int) -> pickle.Pickler:
......
......@@ -66,7 +66,6 @@ else:
try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
import_ssd_offload = True
except ImportError:
......@@ -399,6 +398,10 @@ class FullyShardedDataParallel(nn.Module):
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
if self.ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
self.world_size
......@@ -761,16 +764,14 @@ class FullyShardedDataParallel(nn.Module):
p._is_sharded = True
# Replace p.data with the relevant shard.
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
if self.ssd_offload:
assert isinstance(p, SsdFlatParameter)
sharded_tensor, num_padded = self._get_shard(p.data)
p.point_to_resized_tensor(sharded_tensor)
self.numel_padded_per_param.append(num_padded)
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)
......@@ -987,7 +988,7 @@ class FullyShardedDataParallel(nn.Module):
def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU."""
for p in self.params:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor()
def _load_state_dict(
......@@ -1138,14 +1139,19 @@ class FullyShardedDataParallel(nn.Module):
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu())
else:
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
self.has_full_params = False
if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage.
for p in self.params:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
self._use_fp32_param_shard()
......@@ -1281,10 +1287,11 @@ class FullyShardedDataParallel(nn.Module):
# shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation.
if self.ssd_offload:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
# Gradients also need to be offloaded to SSD otherwise it can result in
# OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.allow_unsafe_changes = True
p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file()
else:
......@@ -1441,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module):
def _free_ssd_offload(self) -> None:
if self.ssd_offload:
for p in self.params:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file(permit_when_tensor_none=True)
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
......@@ -1897,8 +1904,10 @@ class FullyShardedDataParallel(nn.Module):
if self.ssd_offload:
for p in self.params:
assert isinstance(p, SsdFlatParameter)
p.to_tensor()
assert isinstance(p, ssd_offload.SsdParameter)
if not p.is_available():
self._ssd_offload_reset_param_device(p)
p.to_tensor()
self.has_full_params = False
......@@ -2175,13 +2184,25 @@ class FullyShardedDataParallel(nn.Module):
return consolidated_weights
@torch.no_grad()
def _ssd_offload_reset_param_device(self, param: Parameter) -> None:
assert isinstance(param, ssd_offload.SsdParameter)
if param.device != torch.device("cpu"):
param.data = param._fp32_shard
param.tensor = None
@torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
p.data = p._fp32_shard
if import_ssd_offload and self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.to_tensor()
else:
p.data = p._fp32_shard
@torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
......@@ -2192,11 +2213,14 @@ class FullyShardedDataParallel(nn.Module):
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
if self.ssd_offload:
p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True))
else:
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
......
......@@ -31,7 +31,19 @@ import torch
from torch import Tensor
import torch.nn as nn
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
try:
from fairscale.experimental.nn.ssd_offload import (
SsdFlatParameter,
SsdFlatParameterView,
SsdFlatParameterViewProperty,
_register_property,
)
import_ssd_offload = True
except ImportError:
import_ssd_offload = False
pass
from fairscale.utils.state_dict import replace_by_prefix_
if TYPE_CHECKING:
......@@ -116,7 +128,6 @@ class FlatParameter(nn.Parameter):
# Static types.
FlatTypes = Union[FlatParameter, SsdFlatParameter]
ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]]
......@@ -159,6 +170,11 @@ class FlattenParamsWrapper(nn.Module):
ssd_directory: str = "",
):
super().__init__()
if ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.ssd_offload = ssd_offload
self._fpw_module = module
self.is_flattened = False
......@@ -205,7 +221,7 @@ class FlattenParamsWrapper(nn.Module):
# support.
raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}")
self.flat_params: List[FlatTypes] = []
self.flat_params: List[nn.Parameter] = []
# Prepare flat param names.
if flat_param_names is None:
......@@ -215,7 +231,7 @@ class FlattenParamsWrapper(nn.Module):
if len(flat_param_names) != len(set(flat_param_names)):
raise ValueError("Each flat param must be given a unique name")
self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names]
flat_param: Optional[FlatTypes] = None
flat_param: Optional[nn.Parameter] = None
# Init all flat_params.
for new_p_set in self._param_sets:
......@@ -224,6 +240,7 @@ class FlattenParamsWrapper(nn.Module):
assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter.from_tensors(tensors=params)
flat_param.allow_unsafe_changes = True
flat_param.set_file_params(fname, 0)
else:
flat_param = FlatParameter(params, params[0].requires_grad)
......@@ -309,13 +326,13 @@ class FlattenParamsWrapper(nn.Module):
@property
def _param_infos(self) -> Iterator[Tuple[str, nn.Module, str]]:
return chain(*[p._param_infos for p in self.flat_params])
return chain(*[p._param_infos for p in self.flat_params]) # type: ignore
@property
def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]:
return chain(*[p._shared_param_infos for p in self.flat_params])
return chain(*[p._shared_param_infos for p in self.flat_params]) # type: ignore
def _flatten_params(self, flat_params: List[FlatTypes]) -> None:
def _flatten_params(self, flat_params: List[nn.Parameter]) -> None:
"""Flatten the managed parameters and replaced the original
attributes with views to the flat params.
"""
......@@ -331,6 +348,7 @@ class FlattenParamsWrapper(nn.Module):
# deregister the names as parameters
for _, m, n in self._param_infos:
delattr(m, n)
for _, _, m, n, _, _ in self._shared_param_infos:
delattr(m, n)
......@@ -372,8 +390,13 @@ class FlattenParamsWrapper(nn.Module):
ps = self.get_param_views()
param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
if self.ssd_offload:
assert isinstance(p, SsdFlatParameterView)
_register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id))
else:
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
# Save param views for easy access if anyone still wants to access
# parameters of the module.
......@@ -498,13 +521,13 @@ class FlattenParamsWrapper(nn.Module):
gens = []
for p, data in zip(params, external_data_list):
gens.append(p.get_param_views(data))
gens.append(p.get_param_views(data)) # type: ignore
return chain(*gens)
def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return metadata for a flat param given its index in the flat_params list."""
return self.flat_params[flat_param_idx].metadata()
return self.flat_params[flat_param_idx].metadata() # type: ignore
def _post_state_dict_hook(
......
......@@ -8,6 +8,7 @@ Testing SsdFlatParameter and SsdTensorHandle modules.
"""
import filecmp
import functools
import os
import tempfile
......@@ -15,11 +16,12 @@ import numpy as np
import pytest
import torch
import fairscale.experimental.nn.ssd_offload as so
from fairscale.utils import torch_version
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
def _init():
......@@ -77,10 +79,32 @@ def test_ssd_handle_dispatch_bwd():
assert torch.equal(ssd_handle.grad, orig_copy.grad)
def test_ssd_handle_train_simple():
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
def test_ssd_handle_dispatch_bwd_hook():
_init()
def post_backward_hook(name, grad):
print(f"BACKWARD HOOK for tensor {name} CALLED")
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
one = torch.ones((1), requires_grad=True).cuda()
orig_copy = ssd_handle.data
cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
ssd_handle.data = cuda_copy
ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle"))
one.register_hook(functools.partial(post_backward_hook, "one"))
y1 = ssd_handle + one
y1.sum().backward()
def test_ssd_handle_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
......@@ -92,6 +116,7 @@ def test_ssd_handle_train_simple():
orig_copy.requires_grad = True
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.flush_on_dirty = False
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
......@@ -102,15 +127,15 @@ def test_ssd_handle_train_simple():
y1 = ssd_handle + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = orig_copy + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_handle.point_to_file(f.name, 0)
assert torch.equal(ssd_handle.to_tensor(), orig_copy)
......@@ -169,9 +194,6 @@ def test_torch_save_load_ssd_flat_param_on_mem():
def test_ssd_param_train_simple():
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4))
......@@ -183,6 +205,7 @@ def test_ssd_param_train_simple():
ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
ssd_param.point_to_tensor(orig_copy)
ssd_param.flush_on_dirty = False
ssd_param.set_file_params(f.name, 0)
ssd_param.to_file(release_tensor_after_write=True)
......@@ -193,15 +216,17 @@ def test_ssd_param_train_simple():
y1 = ssd_param + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
# Test to see if Dirty is being calculated correctly when optimizer modifies
# ssd_param
assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = param + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_param.point_to_file(f.name, 0)
assert torch.equal(ssd_param.to_tensor(), param)
......@@ -211,8 +236,186 @@ def test_ssd_flat_parameter_basic():
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
def test_ssd_flat_parameter_view_modify():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
ssd_flat_param.flush_on_dirty = False
param_views = list(ssd_flat_param.get_param_views())
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
ssd_flat_param.to_file()
assert ssd_flat_param.storage_state == so.StorageState.ON_DISK
assert param_views[0].tensor is None
param_views[0] += 0.1
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
def test_ssd_flat_parameter_view_bwd():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
refa_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refb_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refc_param = (
torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.data = cuda_copy
one = torch.ones((1), requires_grad=True, device=ssd_flat_param.device)
y1 = ssd_flat_param.views[0] + one
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_view_bwd_parameterization():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
layer1 = torch.nn.Linear(32, 4, bias=False)
layer2 = torch.nn.Linear(32, 4, bias=False)
layer3 = torch.nn.Linear(128, 1, bias=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0
)
torch.nn.utils.parametrize.register_parametrization(
layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0)
)
torch.nn.utils.parametrize.register_parametrization(
layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1)
)
torch.nn.utils.parametrize.register_parametrization(
layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2)
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.to_file(release_tensor_after_write=False)
ssd_flat_param.data = cuda_copy
one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device)
y1 = layer1.forward(one)
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_direct_to_file():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
param_views = list(ssd_flat_param.get_param_views())
......@@ -224,3 +427,8 @@ def test_ssd_flat_parameter_basic():
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
......@@ -16,17 +16,17 @@ import torch
from torch import nn
import torch.distributed
import fairscale.experimental.nn.ssd_offload as so
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
print(f"torch version {torch_version()}")
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod
......@@ -137,8 +137,6 @@ def rename_test(testcase_func, param_num, param):
class TestSsdMemory(DistributedTest):
def test_memory_benchmark(self):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_memory_benchmark, config={})
spawn_and_init(test_fn)
......@@ -218,8 +216,6 @@ class TimeKeeper:
class TestModuleProperties(DistributedTest):
@parameterized.expand(CONFIG, name_func=rename_test)
def test_named_parameters(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_named_params, config=config)
spawn_and_init(test_fn)
......@@ -264,23 +260,17 @@ class TestModuleProperties(DistributedTest):
class TestSsdLoading(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_eval(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG, name_func=rename_test)
def test_transformer_parameterized(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_train_flatten_params_wrapper(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
spawn_and_init(test_fn)
......@@ -288,6 +278,8 @@ class TestSsdLoading(DistributedTest):
@classmethod
def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config):
SIZE = 16 * 16
LR = 0.01
MOMENTUM = 0.1
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
with tempfile.TemporaryDirectory() as current_tempdir:
......@@ -305,7 +297,7 @@ class TestSsdLoading(DistributedTest):
model = FullyShardedDataParallel(model, **config)
model_device = torch.device("cuda")
model.train()
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint_file = tempfile.NamedTemporaryFile()
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
......@@ -322,9 +314,13 @@ class TestSsdLoading(DistributedTest):
input = model.get_input(torch.device("cuda"))
output = model(*input)
pre_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} pre_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
if i == 0:
......@@ -332,18 +328,23 @@ class TestSsdLoading(DistributedTest):
# so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name)
torch.save({"model": model.state_dict()}, checkpoint_file.name)
# reset momentum just after checkpoint save
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint = torch.load(checkpoint_file.name)
model.load_state_dict(checkpoint["model"])
# reset momentum just after checkpoint load
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
# do more iterations after loading checkpoint
for i in range(ITERATIONS - 1):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
post_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} post_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
......