Unverified Commit 51b53ddb authored by Paul Johnson's avatar Paul Johnson Committed by GitHub

[FSDP] ssd_offload fixing backward path (grad_fn) for SsdFlatParameter and...

[FSDP] ssd_offload fixing backward path (grad_fn) for SsdFlatParameter and SsdFlatParameterView (#974)

* [FSDP] fixing backward path for SsdFlatParameter and SsdFlatParameterView when overriding .data

* Get ssd_offload unit tests passing

* [FSDP] get all test_fsdp_offload tests passing w/ ssd_offload on

* Update changelog
parent 58ccb166
...@@ -12,7 +12,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -12,7 +12,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes), - FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes),
verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict() verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict()
and model.load_state_dict(...) and thus successfully checkpoint and recover parameters and model.load_state_dict(...) and thus successfully checkpoint and recover parameters
stored as SsdFlatParameters. stored as SsdFlatParameters. [#964]
- FSDP Ssd Offload: [#974]
* add a __setattr__ override for SsdTensorHandle/SsdFlatParameter to capture when .data is overridden
and perform the necessary checks/updates before the tensor metadata is updated.
* A bug in pytorch core caused re-assigning a tensor subclass's .data to
disable the __torch_dispatch__ mechanism and effectively revert it to a normal
tensor. Since the ssd_offload feature uses __torch_dispatch__ extensively and FSDP overrides
.data, ssd_offload can now only be used with pytorch 1.12.0 or later
(currently the pytorch-nightly release). On older versions the pytest tests are skipped, and
importing ssd_offload or using FSDP with ssd_offload enabled raises an ImportError.
* Split the storage_state ON_CPU into ON_CPU_CLEAN and ON_CPU_DIRTY (a short sketch of these
states follows this list). ON_CPU_CLEAN indicates that the .tensor value and the value stored
on disk are identical, so no write is needed when calling .to_file(). ON_CPU_DIRTY indicates
that .tensor does not match the value stored on disk and a write is necessary. This optimization
is disabled in FSDP, as FSDP does not currently flush variables to disk as soon as the optimizer
modifies values.
* Fix detection in SsdTensorHandle's __torch_dispatch__ of when the handle is used in an in-place
operator or as an out= argument, and call mark_dirty() in those cases.
* Fix grad_fn on SsdFlatParameterViews so that each view's grad_fn points
to a split and view of the ssd_flat_parameter; when X.backward() is called, ssd_flat_param.grad
is properly updated in the backward pass.
* Add unit tests to verify grad_fn and backward hooks are called appropriately on
SsdFlatParameters/SsdFlatParameterViews
* Implement the SsdFlatParameter.from_tensors() direct_to_file option. This allows creating
an SsdFlatParameter and writing it directly to disk, rather than creating a new tensor
first. This avoids instantiating another copy of the tensors and can save memory when
creating very large SsdFlatParameters.
* Add SsdFlatParameterViewProperty for overriding a parameter in an nn.Module,
similar to pytorch core's parametrization code path. This keeps SsdFlatParameterViews
up to date as SsdFlatParameter.data is overridden, by replacing Layer.weight with a
property whose getter returns ssd_flat_param.views[view_id].
* Add an example, SsdFlatParameterViewParameterization, showing how the existing parametrization
mechanism could be used instead of SsdFlatParameterViewProperty. However, this method is less
memory efficient because parametrization always keeps a copy of the original
parameter internally.
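A short sketch of the new clean/dirty lifecycle, distilled from the unit tests in this PR (assumes a torch >= 1.12 environment where ssd_offload imports successfully; illustrative only, not part of the commit):

    import tempfile
    import torch
    import fairscale.experimental.nn.ssd_offload as so

    with tempfile.NamedTemporaryFile() as f:
        handle = so.SsdTensorHandle.from_tensor(torch.randn(4, 4))
        handle.flush_on_dirty = False       # don't write back immediately on mark_dirty()
        handle.set_file_params(f.name, 0)
        handle.to_file()                    # written and released -> ON_DISK
        assert handle.storage_state is so.StorageState.ON_DISK
        handle.to_tensor()                  # read back from file -> ON_CPU_CLEAN
        assert handle.storage_state is so.StorageState.ON_CPU_CLEAN
        handle.add_(1.0)                    # in-place op caught by __torch_dispatch__ -> mark_dirty()
        assert handle.storage_state is so.StorageState.ON_CPU_DIRTY
        handle.to_file()                    # dirty, so the write actually happens -> ON_DISK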
### Fixed ### Fixed
......
...@@ -11,18 +11,24 @@ import io ...@@ -11,18 +11,24 @@ import io
import os import os
import pickle import pickle
from types import TracebackType from types import TracebackType
from typing import IO, Any, BinaryIO, Iterator, List, Optional, Sequence, Tuple, Type, Union from typing import IO, Any, BinaryIO, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union
import numpy as np import numpy as np
import torch import torch
from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL from torch.serialization import DEFAULT_PROTOCOL as DEFAULT_PROTOCOL
from fairscale.utils import torch_version
try: try:
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
except ImportError: except ImportError:
# The PyTorch version(<1.9) we test with does not support the tree_map API. # The PyTorch version(<1.9) we test with does not support the tree_map API.
pass pass
if torch_version() < (1, 12, 0):
raise ImportError(
f"ssd_offload only works on torch versions 1.12.0 and beyond, but torch version is: {torch.__version__}"
)
DEFAULT_CHUNK_SIZE = 2048 * 2048 DEFAULT_CHUNK_SIZE = 2048 * 2048
...@@ -69,7 +75,10 @@ def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0) ...@@ -69,7 +75,10 @@ def read(input_tensor: torch.Tensor, filename: str, file_offset_bytes: int = 0)
chunk_start = i * chunk_size_bytes chunk_start = i * chunk_size_bytes
chunk_end = min(size_in_bytes, chunk_start + chunk_size_bytes) chunk_end = min(size_in_bytes, chunk_start + chunk_size_bytes)
data_read = f.readinto(input_tensor_mv[chunk_start:chunk_end]) data_read = f.readinto(input_tensor_mv[chunk_start:chunk_end])
assert data_read == chunk_end - chunk_start if data_read != chunk_end - chunk_start:
raise RuntimeError(
f"Attempted to read {chunk_end - chunk_start} more bytes from {filename}, but only read: {data_read} bytes. Total Bytes read = {chunk_start + data_read}, total bytes expected: {size_in_bytes}"
)
class StorageState(Enum): class StorageState(Enum):
...@@ -82,7 +91,8 @@ class StorageState(Enum): ...@@ -82,7 +91,8 @@ class StorageState(Enum):
UNALLOCATED = auto() UNALLOCATED = auto()
ON_DISK = auto() ON_DISK = auto()
ON_CPU = auto() ON_CPU_CLEAN = auto()
ON_CPU_DIRTY = auto()
class SsdTensorHandle(torch.Tensor): class SsdTensorHandle(torch.Tensor):
...@@ -109,12 +119,26 @@ class SsdTensorHandle(torch.Tensor): ...@@ -109,12 +119,26 @@ class SsdTensorHandle(torch.Tensor):
@staticmethod @staticmethod
def __new__( def __new__(
cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool = False cls: Type[SsdTensorHandle],
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = False,
device: torch.device = torch.device("cpu"),
flush_on_dirty: bool = True,
allow_unsafe_changes: bool = False,
) -> SsdTensorHandle: ) -> SsdTensorHandle:
r = super(SsdTensorHandle, cls)._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad) # type: ignore r = super(SsdTensorHandle, cls)._make_wrapper_subclass(cls, shape, dtype=dtype, requires_grad=requires_grad, device=device) # type: ignore
return r return r
def __init__(self, shape: torch.Size, dtype: torch.dtype, requires_grad: bool) -> None: def __init__(
self,
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool,
device: torch.device = torch.device("cpu"),
flush_on_dirty: bool = True,
allow_unsafe_changes: bool = False,
) -> None:
self._unpickle_f: Optional[Union[BinaryIO, IO[bytes]]] = None self._unpickle_f: Optional[Union[BinaryIO, IO[bytes]]] = None
self._shape = shape self._shape = shape
...@@ -128,8 +152,17 @@ class SsdTensorHandle(torch.Tensor): ...@@ -128,8 +152,17 @@ class SsdTensorHandle(torch.Tensor):
self.offset = -1 self.offset = -1
# valid if loaded to memory # valid if loaded to memory
self.tensor: Optional[torch.Tensor] = None self.tensor: Optional[torch.Tensor] = None
self.requires_grad = requires_grad
self.storage_state = StorageState.UNALLOCATED self.storage_state = StorageState.UNALLOCATED
self.flush_on_dirty = flush_on_dirty
self.allow_unsafe_changes = allow_unsafe_changes
def mark_dirty(self) -> None:
assert self.tensor is not None
assert self.storage_state in [StorageState.ON_CPU_CLEAN, StorageState.ON_CPU_DIRTY]
self.storage_state = StorageState.ON_CPU_DIRTY
# hack to force write on mark_dirty
if self.flush_on_dirty:
self.to_file()
@classmethod @classmethod
def from_file( def from_file(
...@@ -143,7 +176,7 @@ class SsdTensorHandle(torch.Tensor): ...@@ -143,7 +176,7 @@ class SsdTensorHandle(torch.Tensor):
@classmethod @classmethod
def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle: def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle:
"""Returns a new SsdTensorHandle from a tensor.""" """Returns a new SsdTensorHandle from a tensor."""
handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad) handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device)
handle.point_to_tensor(tensor) handle.point_to_tensor(tensor)
return handle return handle
...@@ -165,10 +198,11 @@ class SsdTensorHandle(torch.Tensor): ...@@ -165,10 +198,11 @@ class SsdTensorHandle(torch.Tensor):
def point_to_tensor(self, tensor: torch.Tensor) -> None: def point_to_tensor(self, tensor: torch.Tensor) -> None:
assert self.tensor is None assert self.tensor is None
assert self._shape == tensor.shape if not self.allow_unsafe_changes:
assert self._shape == tensor.shape
assert self._dtype == tensor.dtype assert self._dtype == tensor.dtype
self.tensor = tensor self.tensor = tensor
self.storage_state = StorageState.ON_CPU self.storage_state = StorageState.ON_CPU_DIRTY
# if resizing a handle that is part of an ssd buffer, care must be taken that the new size # if resizing a handle that is part of an ssd buffer, care must be taken that the new size
# doesn't conflict with adjacent handles! # doesn't conflict with adjacent handles!
...@@ -185,21 +219,33 @@ class SsdTensorHandle(torch.Tensor): ...@@ -185,21 +219,33 @@ class SsdTensorHandle(torch.Tensor):
if self.tensor is not None: if self.tensor is not None:
return self.tensor return self.tensor
else: else:
result_tensor = torch.empty(size=self._shape, dtype=self._dtype, requires_grad=self.requires_grad) if self.device != torch.device("cpu"):
raise RuntimeError(
f"to_tensor called on an SsdTensorHandle when the tensor has been offloaded to disk. self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!"
)
result_tensor = torch.empty(size=self.shape, dtype=self.dtype, requires_grad=self.requires_grad)
self.copy_into_tensor(result_tensor) self.copy_into_tensor(result_tensor)
self.tensor = result_tensor self.tensor = result_tensor
self.storage_state = StorageState.ON_CPU self.storage_state = StorageState.ON_CPU_CLEAN
return self.tensor return self.tensor
def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None: def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
"""Saves the tensor to disk and releases memory if specified.""" """Saves the tensor to disk and releases memory if specified."""
assert self.tensor is not None or permit_when_tensor_none assert self.tensor is not None or permit_when_tensor_none
# if it's available in Memory but not modified, no need to write-back
if self.tensor is not None: if self.tensor is not None:
write(self.tensor, self.filename, self.offset * self.tensor.element_size()) if self.storage_state is StorageState.ON_CPU_DIRTY:
if self.device != torch.device("cpu"):
raise RuntimeError(
f"to_file called on an SsdTensorHandle when self.device = {self.device}, it should be {torch.device('cpu')}. Some unexpected .data override has occured!!"
)
write(self.tensor, self.filename, self.offset * self.tensor.element_size())
if release_tensor_after_write: if release_tensor_after_write:
self.tensor = None self.tensor = None
self.storage_state = StorageState.ON_DISK self.storage_state = StorageState.ON_DISK
else:
self.storage_state = StorageState.ON_CPU_CLEAN
def copy_into_tensor(self, tensor: torch.Tensor) -> None: def copy_into_tensor(self, tensor: torch.Tensor) -> None:
"""Copies SsdTensorHandle's data into the given tensor. """Copies SsdTensorHandle's data into the given tensor.
...@@ -211,7 +257,9 @@ class SsdTensorHandle(torch.Tensor): ...@@ -211,7 +257,9 @@ class SsdTensorHandle(torch.Tensor):
function. This can be useful for calls like named_parameters() when function. This can be useful for calls like named_parameters() when
the tensor is already offloaded to disk. the tensor is already offloaded to disk.
""" """
assert self._shape == tensor.shape # ideally this should be checked but .data shenanigans forces it to
# be disabled due to the way FSDP shards parameters
# assert self._shape == tensor.shape
assert self._dtype == tensor.dtype assert self._dtype == tensor.dtype
if self.tensor is not None: if self.tensor is not None:
tensor.copy_(self.tensor) tensor.copy_(self.tensor)
...@@ -229,25 +277,48 @@ class SsdTensorHandle(torch.Tensor): ...@@ -229,25 +277,48 @@ class SsdTensorHandle(torch.Tensor):
versions to track if modifications have been made. If we detect changes to the versions to track if modifications have been made. If we detect changes to the
tensor, we write it to the file maintained by the Handle. tensor, we write it to the file maintained by the Handle.
""" """
func_name = func.overloadpacket.__name__
ssd_tensor_handles = [] ssd_tensor_handles = []
def unwrap(e: Any) -> torch.Tensor: def unwrap(e: Any) -> torch.Tensor:
if isinstance(e, SsdTensorHandle): if isinstance(e, SsdTensorHandle):
t = e.to_tensor() t = e.to_tensor()
ssd_tensor_handles.append((e, t._version)) # type: ignore ssd_tensor_handles.append(e)
return t return t
else: else:
return e return e
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for e, saved_version in ssd_tensor_handles: for e in ssd_tensor_handles:
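# Heuristic: ATen in-place overloads end in a single trailing underscore (e.g. "add_"),
# Python in-place dunders start with "__i" (e.g. "__iadd__"), and plain dunders
# such as "__eq__" are not in-place.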
inplace_is_this_tensor = func.__name__[-1] == "_" and e is args[0] inplace_is_this_tensor = (
(func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i")
) and e is args[0]
out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"] out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"]
if inplace_is_this_tensor or out_is_this_tensor: if inplace_is_this_tensor or out_is_this_tensor:
e.to_file() e.mark_dirty()
return r return r
def __setattr__(self, name: str, value: Any) -> None:
if name == "data":
assert isinstance(value, torch.Tensor)
if not self.allow_unsafe_changes:
# Validate .data changes; when allow_unsafe_changes is set these checks are
# skipped and the user had better know what they are doing!
if self.storage_state == StorageState.ON_CPU_DIRTY:
raise RuntimeError(
"Attempting to override tensor when the existing tensor is dirty, this is an error!"
)
if value.shape != self.shape:
raise RuntimeError(
f"Attempting to override tensor metadata using .data to change shape of tensor. Orig shape: {self.shape} New shape: {value.shape}"
)
if value.requires_grad != self.requires_grad:
raise RuntimeError(
f"Attempting to override tensor metadata using .data to change requires_grad. Orig value: {self.requires_grad} New value: {value.requires_grad}"
)
self.tensor = value
super(SsdTensorHandle, self).__setattr__(name, value)
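# Illustration of the intended .data override behavior (sketch only, not part of
# this module's API surface; the path below is hypothetical): with
# allow_unsafe_changes left False, a clean handle accepts a same-shaped tensor
# but rejects shape or requires_grad changes:
#
#   handle = SsdTensorHandle.from_tensor(torch.randn(4, 4))
#   handle.set_file_params("/tmp/handle.bin", 0)
#   handle.to_file()                      # flush, so the handle is no longer dirty
#   handle.data = torch.randn(4, 4)       # ok: same shape / requires_grad
#   handle.data = torch.randn(2, 2)       # raises RuntimeError (shape change)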
@classmethod @classmethod
def __unpickle__( def __unpickle__(
cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool, filename: str cls: Type[SsdTensorHandle], shape: torch.Size, dtype: torch.dtype, requires_grad: bool, filename: str
...@@ -264,7 +335,7 @@ class SsdTensorHandle(torch.Tensor): ...@@ -264,7 +335,7 @@ class SsdTensorHandle(torch.Tensor):
head, tail = os.path.split(self.filename) head, tail = os.path.split(self.filename)
filename = os.path.join(self.override_directory_path, tail) filename = os.path.join(self.override_directory_path, tail)
if self.is_available(): if self.is_available():
byte_iter = iter(TensorChunkingIterator(self.tensor)) byte_iter = iter(TensorChunkingIterator(self.tensor))  # type: ignore
else: else:
byte_iter = iter( byte_iter = iter(
FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size()) FileChunkingIterator(self.filename, expected_size_bytes=self.numel() * self.element_size())
...@@ -358,19 +429,29 @@ class TorchSaver: ...@@ -358,19 +429,29 @@ class TorchSaver:
class SsdParameter(SsdTensorHandle, torch.nn.Parameter): class SsdParameter(SsdTensorHandle, torch.nn.Parameter):
@classmethod @classmethod
def from_tensor(cls: Type[SsdParameter], tensor: SsdTensorHandle) -> SsdParameter: # type: ignore def from_tensor(cls: Type[SsdParameter], tensor: SsdTensorHandle) -> SsdParameter: # type: ignore
r = cls(tensor.shape, tensor.dtype, tensor.requires_grad) r = cls(tensor.shape, tensor.dtype, tensor.requires_grad, device=tensor.device)
r.point_to_tensor(tensor) r.point_to_tensor(tensor)
return r return r
@staticmethod @staticmethod
def __new__( def __new__(
cls: Type[SsdParameter], shape: torch.Size, dtype: torch.dtype, requires_grad: bool = True cls: Type[SsdParameter],
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> SsdParameter: ) -> SsdParameter:
r = super(SsdParameter, cls).__new__(cls, shape, dtype=dtype, requires_grad=requires_grad) r = super(SsdParameter, cls).__new__(cls, shape=shape, dtype=dtype, requires_grad=requires_grad, device=device)
return r # type: ignore return r # type: ignore
def __init__(self, shape: torch.Size, dtype: torch.dtype, requires_grad: bool = True) -> None: def __init__(
super(SsdParameter, self).__init__(shape, dtype, requires_grad) self,
shape: torch.Size,
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> None:
super(SsdParameter, self).__init__(shape=shape, dtype=dtype, requires_grad=requires_grad, device=device)
class SsdFlatParameter(SsdParameter): class SsdFlatParameter(SsdParameter):
...@@ -381,7 +462,11 @@ class SsdFlatParameter(SsdParameter): ...@@ -381,7 +462,11 @@ class SsdFlatParameter(SsdParameter):
""" """
def __new__( def __new__(
cls: Type[SsdFlatParameter], shapes: Sequence[torch.Size], dtype: torch.dtype, requires_grad: bool = True cls: Type[SsdFlatParameter],
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
) -> SsdFlatParameter: ) -> SsdFlatParameter:
"""Make an object using the parent's __new__ function.""" """Make an object using the parent's __new__ function."""
...@@ -390,10 +475,18 @@ class SsdFlatParameter(SsdParameter): ...@@ -390,10 +475,18 @@ class SsdFlatParameter(SsdParameter):
raise ValueError("An non-empty list or tuple argument is needed") raise ValueError("An non-empty list or tuple argument is needed")
size = sum([np.prod(s) for s in shapes]) size = sum([np.prod(s) for s in shapes])
r = super(SsdFlatParameter, cls).__new__(cls, torch.Size((size,)), dtype=dtype, requires_grad=requires_grad) r = super(SsdFlatParameter, cls).__new__(
cls, torch.Size((size,)), dtype=dtype, requires_grad=requires_grad, device=device
)
return r # type: ignore return r # type: ignore
def __init__(self, shapes: Sequence[torch.Size], dtype: torch.dtype, requires_grad: bool = True): def __init__(
self,
shapes: Sequence[torch.Size],
dtype: torch.dtype,
requires_grad: bool = True,
device: torch.device = torch.device("cpu"),
):
"""Initialize the _param_numels and _param_shapes lists.""" """Initialize the _param_numels and _param_shapes lists."""
self._param_shapes = shapes self._param_shapes = shapes
self._param_numels = [np.prod(s) for s in shapes] self._param_numels = [np.prod(s) for s in shapes]
...@@ -402,6 +495,7 @@ class SsdFlatParameter(SsdParameter): ...@@ -402,6 +495,7 @@ class SsdFlatParameter(SsdParameter):
self.numel() <= total_numels self.numel() <= total_numels
), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}" ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
self.views: List[SsdFlatParameterView] = []
# These are set by FPW class below, not by this class itself. # These are set by FPW class below, not by this class itself.
self._param_infos: List[Tuple[str, torch.nn.Module, str]] = [] self._param_infos: List[Tuple[str, torch.nn.Module, str]] = []
self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = [] self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = []
...@@ -410,6 +504,31 @@ class SsdFlatParameter(SsdParameter): ...@@ -410,6 +504,31 @@ class SsdFlatParameter(SsdParameter):
shape=torch.Size((total_numels,)), dtype=dtype, requires_grad=requires_grad shape=torch.Size((total_numels,)), dtype=dtype, requires_grad=requires_grad
) )
def __setattr__(self, name: str, value: Any) -> None:
super(SsdFlatParameter, self).__setattr__(name, value)
if name == "data":
# if .data has changed, we need to totally destroy any existing views because things
# like device might have changed. It won't destroy any pointers to those views outside
# of here, however resetting self.views will trigger the old view's assertion in
# __torch_dispatch__ that it is the current view of its parent object
self.views = []
self._refresh_views()
def _invalidate_views(self) -> None:
for v in self.views:
v.tensor = None
@torch.enable_grad()
def _refresh_views(self) -> None:
if self._shape != self.shape:
self.views = []
return
if len(self.views) == 0:
self.views = [s.view(v) for s, v in zip(self.split(self._param_numels), self._param_shapes)] # type: ignore
else:
for v, t, s in zip(self.views, self.tensor.split(self._param_numels), self._param_shapes):
v.tensor = t.view(s)
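# Note: _invalidate_views() only clears each view's cached .tensor so the next
# access re-resolves through the parent, while _refresh_views() rebuilds the
# views only when the current .data still has the full (unsharded) shape;
# otherwise self.views stays empty until the full tensor is restored.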
def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]: def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]:
"""Return a generator of views that map to the original parameters.""" """Return a generator of views that map to the original parameters."""
# Note, self.data could be sharded, so its numel is <= to the sum. # Note, self.data could be sharded, so its numel is <= to the sum.
...@@ -425,7 +544,15 @@ class SsdFlatParameter(SsdParameter): ...@@ -425,7 +544,15 @@ class SsdFlatParameter(SsdParameter):
) )
return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes)) return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes))
else: else:
return (t.view(s) for (t, s) in zip(self.split(self._param_numels), self._param_shapes)) # this needs to return SsdFlatParameterViews
if not self.is_available():
self.to_tensor()
if len(self.views) == 0:
raise RuntimeError(
"Trying to call get_param_views when self.views is empty, this means that .data games have been played and the current .data shape doesn't match the constructed shape."
)
return (v for v in self.views)
def metadata(self) -> Tuple[List[str], Sequence[torch.Size], List[int]]: def metadata(self) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return tuple of (names, shapes, numels) metadata for this flat parameter.""" """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
...@@ -434,7 +561,11 @@ class SsdFlatParameter(SsdParameter): ...@@ -434,7 +561,11 @@ class SsdFlatParameter(SsdParameter):
@classmethod @classmethod
def from_tensors( def from_tensors(
cls: Type[SsdFlatParameter], tensors: Sequence[torch.Tensor], direct_to_file: bool = False cls: Type[SsdFlatParameter],
tensors: Sequence[torch.Tensor],
direct_to_file: bool = False,
filename: str = "",
offset: int = 0,
) -> "SsdFlatParameter": ) -> "SsdFlatParameter":
"""Returns a new SsdFlatParameter from a sequence of tensors.""" """Returns a new SsdFlatParameter from a sequence of tensors."""
assert ( assert (
...@@ -449,17 +580,75 @@ class SsdFlatParameter(SsdParameter): ...@@ -449,17 +580,75 @@ class SsdFlatParameter(SsdParameter):
if any(isinstance(t, SsdFlatParameter) for t in tensors): if any(isinstance(t, SsdFlatParameter) for t in tensors):
raise ValueError("Nesting SsdFlatParameter is not supported") raise ValueError("Nesting SsdFlatParameter is not supported")
handle = cls(shapes=[t.size() for t in tensors], dtype=tensors[0].dtype, requires_grad=tensors[0].requires_grad) requires_grad = tensors[0].requires_grad
dtype = tensors[0].dtype
device = tensors[0].device
for t in tensors:
if t.requires_grad != requires_grad:
raise RuntimeError("Not all tensors have identical requires_grad option")
if t.dtype != dtype:
raise RuntimeError("Not all tensors have identical dtype option")
if t.device != device:
raise RuntimeError("Not all tensors have identical device option")
handle = cls(
shapes=[t.size() for t in tensors],
dtype=tensors[0].dtype,
requires_grad=tensors[0].requires_grad,
device=device,
)
handle.set_file_params(filename, offset)
if direct_to_file: if direct_to_file:
assert False, "direct_to_file not implemented yet" assert filename != ""
pass offset = offset
for t in tensors:
write(t, handle.filename, offset)
offset += t.numel() * t.element_size()
handle.storage_state = StorageState.ON_DISK
else: else:
tensor = torch.cat( tensor = torch.cat(
[t.detach().reshape(-1) if isinstance(t, torch.nn.Parameter) else t.reshape(-1) for t in tensors], 0 [t.reshape(-1) if isinstance(t, torch.nn.Parameter) else t.reshape(-1) for t in tensors],
) 0,
).detach()
tensor.requires_grad_()
handle.point_to_tensor(tensor) handle.point_to_tensor(tensor)
return handle return handle
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore
func_name = func.overloadpacket.__name__
r = super(SsdFlatParameter, cls).__torch_dispatch__(func, types, args, kwargs) # type: ignore
if func_name.startswith("split"):
assert isinstance(args[0], SsdFlatParameter)
parent = args[0]
return [SsdFlatParameterView(parent, t, idx) for idx, t in enumerate(r)]
else:
return r
# need to subclass these methods to support Views
def point_to_tensor(self, tensor: torch.Tensor) -> None:
super(SsdFlatParameter, self).point_to_tensor(tensor)
self._refresh_views()
def point_to_file(self, filename: str, offset: int) -> None:
super(SsdFlatParameter, self).point_to_file(filename, offset)
self._invalidate_views()
def to_tensor(self) -> torch.Tensor:
call_refresh_views = False
if self.tensor is None:
call_refresh_views = True
result = super(SsdFlatParameter, self).to_tensor()
if call_refresh_views:
self._refresh_views()
return result
def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
super(SsdFlatParameter, self).to_file(permit_when_tensor_none, release_tensor_after_write)
self._invalidate_views()
@classmethod @classmethod
def __unpickle_SFP__( def __unpickle_SFP__(
cls: Type[SsdFlatParameter], cls: Type[SsdFlatParameter],
...@@ -494,6 +683,194 @@ class SsdFlatParameter(SsdParameter): ...@@ -494,6 +683,194 @@ class SsdFlatParameter(SsdParameter):
) )
class SsdFlatParameterView(torch.Tensor):
"""
Represents a view into an SsdFlatParameter. It is needed due to FSDP's usage of flattening parameters.
"""
def __new__(
cls: Type[SsdFlatParameterView], parent: SsdFlatParameter, tensor: torch.Tensor, id: int
) -> SsdFlatParameterView:
r = super(SsdFlatParameterView, cls)._make_wrapper_subclass(cls, tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad, device=tensor.device) # type: ignore
return r
def __init__(self: SsdFlatParameterView, parent: SsdFlatParameter, tensor: torch.Tensor, id: int) -> None:
self.parent = parent
self.tensor: Optional[torch.Tensor] = tensor
self.id = id
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore
"""Intercepts all operations performed on this handle object.
Before any operation, the tensor attribute is unwrapped from the handle
and used in the operation. We keep a reference to each unwrapped view so that
in-place or out= operations can mark the parent SsdFlatParameter dirty and it
gets written back to the file maintained by the handle.
"""
func_name = func.overloadpacket.__name__
ssd_tensor_handles = []
def unwrap(e: Any) -> torch.Tensor:
if isinstance(e, SsdFlatParameterView):
if not e.parent.is_available():
e.parent.to_tensor()
# The first condition handles the case where e.parent.views is still being
# constructed (the list comprehension building it hasn't completed yet).
if len(e.parent.views) != 0 and e is not e.parent.views[e.id]:
raise RuntimeError(
"This view should no longer be used as the parent object has had it's .data overwritten (e.parent.views[e.id])!!!"
)
# e.parent will ensure that e.tensor is valid and points to tensor view
t = e.tensor
ssd_tensor_handles.append(e)
return t
else:
return e
r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
for e in ssd_tensor_handles:
inplace_is_this_tensor = (
(func_name.endswith("_") and not func_name.endswith("__")) or func_name.startswith("__i")
) and e is args[0]
out_is_this_tensor = False if "out" not in kwargs else e is kwargs["out"]
if inplace_is_this_tensor or out_is_this_tensor:
e.parent.mark_dirty()
if func_name.startswith("view"):
assert isinstance(args[0], SsdFlatParameterView)
flat_view = args[0]
return SsdFlatParameterView(flat_view.parent, r, flat_view.id)
return r
# ###################################
# ### BEGIN OVERRIDE_PROPERTY FNs ###
# ###################################
# This code is taken mostly from pytorch core parameterization
# pytorch/torch/nn/utils/parametrize.py
def _inject_new_class(module: torch.nn.Module) -> None:
r"""Sets up a module to be parametrized.
This works by substituting the class of the module by a class
that extends it to be able to inject a property
Args:
module (nn.Module): module into which to inject the property
"""
cls = module.__class__
def getstate(self): # type: ignore
raise RuntimeError(
"Serialization of parametrized modules is only "
"supported through state_dict(). See:\n"
"https://pytorch.org/tutorials/beginner/saving_loading_models.html"
"#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
)
param_cls = type(
f"Parametrized{cls.__name__}",
(cls,),
{
"__getstate__": getstate,
},
)
module.__class__ = param_cls
module.override_properties: Dict[str, Callable[[], torch.Tensor]] = {} # type: ignore
# setattr(module, "override_properties", {})
def _inject_property(module: torch.nn.Module, property_name: str) -> None:
r"""Injects a property into module[property_name].
It assumes that the class in the module has already been modified from its
original one using _inject_new_class and that the tensor under :attr:`property_name`
has already been moved out
Args:
module (nn.Module): module into which to inject the property
property_name (str): name of the property to create
"""
def get_parametrized(self: torch.nn.Module) -> torch.Tensor:
prop: Callable[[], torch.Tensor] = self.override_properties[property_name] # type: ignore
# Evaluate the registered property callable to obtain the current tensor
return prop()
def set_original(self: torch.nn.Module, value: Callable[[], torch.Tensor]) -> None:
self.override_properties[property_name] = value # type: ignore
def del_fn(self: torch.nn.Module) -> None:
_remove_property(self, property_name)
setattr(module.__class__, property_name, property(get_parametrized, set_original, del_fn))
def _register_property(module: torch.nn.Module, property_name: str, property_value: Callable[[], torch.Tensor]) -> None:
has_injected_class = hasattr(module, "override_properties")
if not has_injected_class:
_inject_new_class(module)
if hasattr(module, property_name):
delattr(module, property_name)
module.override_properties[property_name] = property_value # type: ignore
_inject_property(module, property_name)
def _remove_property(module: torch.nn.Module, property_name: str, new_property_value: Optional[Any] = None) -> None:
delattr(module.__class__, property_name)
del module.override_properties[property_name] # type: ignore
# Roll back the parametrized class if no other buffer or parameter
# is currently parametrized in this class
if len(module.override_properties) == 0: # type: ignore
delattr(module, "override_properties")
# Restore class
orig_cls = module.__class__.__bases__[0]
module.__class__ = orig_cls
if new_property_value is not None:
setattr(module.__class__, property_name, new_property_value)
# #################################
# ### END OVERRIDE_PROPERTY FNs ###
# #################################
class SsdFlatParameterViewProperty:
"""
Allows for a mutable view to replace a layer's trainable parameters.
This is needed because FSDP changes .data under the covers, and SsdFlatParameter
cannot simply rely on that, since each view (of type SsdFlatParameterView) keeps
an internal representation. So every time we access a view, we need to
make sure we get the up-to-date version, not the original version captured
when the parameters were flattened.
"""
def __init__(self, parent: SsdFlatParameter, view_id: int) -> None:
super().__init__()
self.parent = parent
self.view_id = view_id
def __call__(self) -> SsdFlatParameterView:
return self.parent.views[self.view_id]
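# Usage sketch (illustrative only, mirroring what FlattenParamsWrapper does in
# this PR): replace a layer's weight with a property whose getter always
# resolves to the parameter's current view, so it survives FSDP overriding .data.
# The filename below is hypothetical.
#
#   layer = torch.nn.Linear(4, 32, bias=False)
#   flat = SsdFlatParameter.from_tensors([layer.weight], filename="/tmp/flat_param.bin")
#   _register_property(layer, "weight", SsdFlatParameterViewProperty(flat, 0))
#   layer.weight is flat.views[0]   # the getter returns the live view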
class SsdFlatParameterViewParameterization(torch.nn.Module):
def __init__(self, parent: SsdFlatParameter, view_id: int) -> None:
super().__init__()
self.parent = parent
self.view_id = view_id
def forward(self, *args: Any, **kwargs: Any) -> SsdFlatParameterView:
return self.parent.views[self.view_id]
class DisableMemoizationPicklerModule: class DisableMemoizationPicklerModule:
@classmethod @classmethod
def Pickler(cls, data_buf: io.BytesIO, protocol: int) -> pickle.Pickler: def Pickler(cls, data_buf: io.BytesIO, protocol: int) -> pickle.Pickler:
......
...@@ -66,7 +66,6 @@ else: ...@@ -66,7 +66,6 @@ else:
try: try:
import fairscale.experimental.nn.ssd_offload as ssd_offload import fairscale.experimental.nn.ssd_offload as ssd_offload
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
import_ssd_offload = True import_ssd_offload = True
except ImportError: except ImportError:
...@@ -399,6 +398,10 @@ class FullyShardedDataParallel(nn.Module): ...@@ -399,6 +398,10 @@ class FullyShardedDataParallel(nn.Module):
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk. # Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
if self.ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor( self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
self.world_size self.world_size
...@@ -761,16 +764,14 @@ class FullyShardedDataParallel(nn.Module): ...@@ -761,16 +764,14 @@ class FullyShardedDataParallel(nn.Module):
p._is_sharded = True p._is_sharded = True
# Replace p.data with the relevant shard. # Replace p.data with the relevant shard.
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
if self.ssd_offload: if self.ssd_offload:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
sharded_tensor, num_padded = self._get_shard(p.data)
p.point_to_resized_tensor(sharded_tensor)
self.numel_padded_per_param.append(num_padded)
p.to_file() p.to_file()
else: else:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data) free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params) assert len(self.numel_padded_per_param) == len(self.params)
...@@ -987,7 +988,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -987,7 +988,7 @@ class FullyShardedDataParallel(nn.Module):
def _move_params_to_memory(self) -> None: def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU.""" """Move params from disk to CPU."""
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor() p.to_tensor()
def _load_state_dict( def _load_state_dict(
...@@ -1138,14 +1139,19 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1138,14 +1139,19 @@ class FullyShardedDataParallel(nn.Module):
# Copy any changes made to the full params back into # Copy any changes made to the full params back into
# the corresponding local shards. # the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor) local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard)) if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu())
else:
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free: if safe_to_free:
free_storage_(full_tensor) free_storage_(full_tensor)
self.has_full_params = False self.has_full_params = False
if self.ssd_offload: if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage. # Store tensors in the SSD buffer and free param storage.
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_file() p.to_file()
else: else:
self._use_fp32_param_shard() self._use_fp32_param_shard()
...@@ -1281,10 +1287,11 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1281,10 +1287,11 @@ class FullyShardedDataParallel(nn.Module):
# shard in pinned memory so that we can do a non-blocking transfer. # shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation. # This is only needed during training and not evaluation.
if self.ssd_offload: if self.ssd_offload:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
# Gradients also need to be offloaded to SSD otherwise it can result in # Gradients also need to be offloaded to SSD otherwise it can result in
# OOMs when the memory requirements of a model are larger than host memory. # OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu")) p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.allow_unsafe_changes = True
p._cpu_grad.set_file_params(p.filename + "_grad", 0) p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file() p._cpu_grad.to_file()
else: else:
...@@ -1441,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1441,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module):
def _free_ssd_offload(self) -> None: def _free_ssd_offload(self) -> None:
if self.ssd_offload: if self.ssd_offload:
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_file(permit_when_tensor_none=True) p.to_file(permit_when_tensor_none=True)
def _register_pre_backward_hooks(self, outputs: Any) -> Any: def _register_pre_backward_hooks(self, outputs: Any) -> Any:
...@@ -1897,8 +1904,10 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1897,8 +1904,10 @@ class FullyShardedDataParallel(nn.Module):
if self.ssd_offload: if self.ssd_offload:
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor() if not p.is_available():
self._ssd_offload_reset_param_device(p)
p.to_tensor()
self.has_full_params = False self.has_full_params = False
...@@ -2175,13 +2184,25 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2175,13 +2184,25 @@ class FullyShardedDataParallel(nn.Module):
return consolidated_weights return consolidated_weights
@torch.no_grad()
def _ssd_offload_reset_param_device(self, param: Parameter) -> None:
assert isinstance(param, ssd_offload.SsdParameter)
if param.device != torch.device("cpu"):
param.data = param._fp32_shard
param.tensor = None
@torch.no_grad() @torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params.""" """Use FP32 shard for a list of params."""
if params is None: if params is None:
params = self.params params = self.params
for p in params: for p in params:
p.data = p._fp32_shard if import_ssd_offload and self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.to_tensor()
else:
p.data = p._fp32_shard
@torch.no_grad() @torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None: def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
...@@ -2192,11 +2213,14 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2192,11 +2213,14 @@ class FullyShardedDataParallel(nn.Module):
for p in params: for p in params:
assert p._fp16_shard is not None assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size()) alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_( if self.ssd_offload:
# If move_params_to_cpu is True, this will be non-blocking p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True))
# because _fp32_shard is pinned, otherwise it's a no-op. else:
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True) p._fp16_shard.copy_(
) # If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"]) torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
......
...@@ -31,7 +31,19 @@ import torch ...@@ -31,7 +31,19 @@ import torch
from torch import Tensor from torch import Tensor
import torch.nn as nn import torch.nn as nn
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter try:
from fairscale.experimental.nn.ssd_offload import (
SsdFlatParameter,
SsdFlatParameterView,
SsdFlatParameterViewProperty,
_register_property,
)
import_ssd_offload = True
except ImportError:
import_ssd_offload = False
pass
from fairscale.utils.state_dict import replace_by_prefix_ from fairscale.utils.state_dict import replace_by_prefix_
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -116,7 +128,6 @@ class FlatParameter(nn.Parameter): ...@@ -116,7 +128,6 @@ class FlatParameter(nn.Parameter):
# Static types. # Static types.
FlatTypes = Union[FlatParameter, SsdFlatParameter]
ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]] ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]]
...@@ -159,6 +170,11 @@ class FlattenParamsWrapper(nn.Module): ...@@ -159,6 +170,11 @@ class FlattenParamsWrapper(nn.Module):
ssd_directory: str = "", ssd_directory: str = "",
): ):
super().__init__() super().__init__()
if ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.ssd_offload = ssd_offload
self._fpw_module = module self._fpw_module = module
self.is_flattened = False self.is_flattened = False
...@@ -205,7 +221,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -205,7 +221,7 @@ class FlattenParamsWrapper(nn.Module):
# support. # support.
raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}") raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}")
self.flat_params: List[FlatTypes] = [] self.flat_params: List[nn.Parameter] = []
# Prepare flat param names. # Prepare flat param names.
if flat_param_names is None: if flat_param_names is None:
...@@ -215,7 +231,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -215,7 +231,7 @@ class FlattenParamsWrapper(nn.Module):
if len(flat_param_names) != len(set(flat_param_names)): if len(flat_param_names) != len(set(flat_param_names)):
raise ValueError("Each flat param must be given a unique name") raise ValueError("Each flat param must be given a unique name")
self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names] self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names]
flat_param: Optional[FlatTypes] = None flat_param: Optional[nn.Parameter] = None
# Init all flat_params. # Init all flat_params.
for new_p_set in self._param_sets: for new_p_set in self._param_sets:
...@@ -224,6 +240,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -224,6 +240,7 @@ class FlattenParamsWrapper(nn.Module):
assert ssd_directory != "" assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param") (handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter.from_tensors(tensors=params) flat_param = SsdFlatParameter.from_tensors(tensors=params)
flat_param.allow_unsafe_changes = True
flat_param.set_file_params(fname, 0) flat_param.set_file_params(fname, 0)
else: else:
flat_param = FlatParameter(params, params[0].requires_grad) flat_param = FlatParameter(params, params[0].requires_grad)
...@@ -309,13 +326,13 @@ class FlattenParamsWrapper(nn.Module): ...@@ -309,13 +326,13 @@ class FlattenParamsWrapper(nn.Module):
@property @property
def _param_infos(self) -> Iterator[Tuple[str, nn.Module, str]]: def _param_infos(self) -> Iterator[Tuple[str, nn.Module, str]]:
return chain(*[p._param_infos for p in self.flat_params]) return chain(*[p._param_infos for p in self.flat_params]) # type: ignore
@property @property
def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]: def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]:
return chain(*[p._shared_param_infos for p in self.flat_params]) return chain(*[p._shared_param_infos for p in self.flat_params]) # type: ignore
def _flatten_params(self, flat_params: List[FlatTypes]) -> None: def _flatten_params(self, flat_params: List[nn.Parameter]) -> None:
"""Flatten the managed parameters and replaced the original """Flatten the managed parameters and replaced the original
attributes with views to the flat params. attributes with views to the flat params.
""" """
...@@ -331,6 +348,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -331,6 +348,7 @@ class FlattenParamsWrapper(nn.Module):
# deregister the names as parameters # deregister the names as parameters
for _, m, n in self._param_infos: for _, m, n in self._param_infos:
delattr(m, n) delattr(m, n)
for _, _, m, n, _, _ in self._shared_param_infos: for _, _, m, n, _, _ in self._shared_param_infos:
delattr(m, n) delattr(m, n)
...@@ -372,8 +390,13 @@ class FlattenParamsWrapper(nn.Module): ...@@ -372,8 +390,13 @@ class FlattenParamsWrapper(nn.Module):
ps = self.get_param_views() ps = self.get_param_views()
param_views = [] param_views = []
for (_, m, n), p in zip(self._param_infos, ps): for (_, m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr if self.ssd_offload:
param_views.append(p) assert isinstance(p, SsdFlatParameterView)
_register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id))
else:
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
# Save param views for easy access if anyone still wants to access # Save param views for easy access if anyone still wants to access
# parameters of the module. # parameters of the module.
...@@ -498,13 +521,13 @@ class FlattenParamsWrapper(nn.Module): ...@@ -498,13 +521,13 @@ class FlattenParamsWrapper(nn.Module):
gens = [] gens = []
for p, data in zip(params, external_data_list): for p, data in zip(params, external_data_list):
gens.append(p.get_param_views(data)) gens.append(p.get_param_views(data)) # type: ignore
return chain(*gens) return chain(*gens)
def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]: def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return metadata for a flat param given its index in the flat_params list.""" """Return metadata for a flat param given its index in the flat_params list."""
return self.flat_params[flat_param_idx].metadata() return self.flat_params[flat_param_idx].metadata() # type: ignore
def _post_state_dict_hook( def _post_state_dict_hook(
......
...@@ -8,6 +8,7 @@ Testing SsdFlatParameter and SsdTensorHandle modules. ...@@ -8,6 +8,7 @@ Testing SsdFlatParameter and SsdTensorHandle modules.
""" """
import filecmp import filecmp
import functools
import os import os
import tempfile import tempfile
...@@ -15,11 +16,12 @@ import numpy as np ...@@ -15,11 +16,12 @@ import numpy as np
import pytest import pytest
import torch import torch
import fairscale.experimental.nn.ssd_offload as so try:
from fairscale.utils import torch_version import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release. # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0") pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
def _init(): def _init():
...@@ -77,10 +79,32 @@ def test_ssd_handle_dispatch_bwd(): ...@@ -77,10 +79,32 @@ def test_ssd_handle_dispatch_bwd():
assert torch.equal(ssd_handle.grad, orig_copy.grad) assert torch.equal(ssd_handle.grad, orig_copy.grad)
def test_ssd_handle_train_simple(): def test_ssd_handle_dispatch_bwd_hook():
if torch_version() >= (1, 12, 0): _init()
pytest.skip("to be fixed")
def post_backward_hook(name, grad):
print(f"BACKWARD HOOK for tensor {name} CALLED")
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
one = torch.ones((1), requires_grad=True).cuda()
orig_copy = ssd_handle.data
cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
ssd_handle.data = cuda_copy
ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle"))
one.register_hook(functools.partial(post_backward_hook, "one"))
y1 = ssd_handle + one
y1.sum().backward()
def test_ssd_handle_train_simple():
_init() _init()
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
...@@ -92,6 +116,7 @@ def test_ssd_handle_train_simple(): ...@@ -92,6 +116,7 @@ def test_ssd_handle_train_simple():
orig_copy.requires_grad = True orig_copy.requires_grad = True
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.flush_on_dirty = False
ssd_handle.set_file_params(f.name, 0) ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True) ssd_handle.to_file(release_tensor_after_write=True)
...@@ -102,15 +127,15 @@ def test_ssd_handle_train_simple(): ...@@ -102,15 +127,15 @@ def test_ssd_handle_train_simple():
y1 = ssd_handle + 1 y1 = ssd_handle + 1
optimizer_ssd.zero_grad() optimizer_ssd.zero_grad()
y1.sum().backward() y1.sum().backward()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step() optimizer_ssd.step()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = orig_copy + 1 y2 = orig_copy + 1
optimizer_orig.zero_grad() optimizer_orig.zero_grad()
y2.sum().backward() y2.sum().backward()
optimizer_orig.step() optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_handle.point_to_file(f.name, 0)
assert torch.equal(ssd_handle.to_tensor(), orig_copy) assert torch.equal(ssd_handle.to_tensor(), orig_copy)
...@@ -169,9 +194,6 @@ def test_torch_save_load_ssd_flat_param_on_mem(): ...@@ -169,9 +194,6 @@ def test_torch_save_load_ssd_flat_param_on_mem():
def test_ssd_param_train_simple(): def test_ssd_param_train_simple():
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
_init() _init()
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4)) orig_tensor = torch.randn((4, 4))
...@@ -183,6 +205,7 @@ def test_ssd_param_train_simple(): ...@@ -183,6 +205,7 @@ def test_ssd_param_train_simple():
ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype) ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
ssd_param.point_to_tensor(orig_copy) ssd_param.point_to_tensor(orig_copy)
ssd_param.flush_on_dirty = False
ssd_param.set_file_params(f.name, 0) ssd_param.set_file_params(f.name, 0)
ssd_param.to_file(release_tensor_after_write=True) ssd_param.to_file(release_tensor_after_write=True)
...@@ -193,15 +216,17 @@ def test_ssd_param_train_simple(): ...@@ -193,15 +216,17 @@ def test_ssd_param_train_simple():
y1 = ssd_param + 1 y1 = ssd_param + 1
optimizer_ssd.zero_grad() optimizer_ssd.zero_grad()
y1.sum().backward() y1.sum().backward()
# Test to see if Dirty is being calculated correctly when optimizer modifies
# ssd_param
assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step() optimizer_ssd.step()
assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = param + 1 y2 = param + 1
optimizer_orig.zero_grad() optimizer_orig.zero_grad()
y2.sum().backward() y2.sum().backward()
optimizer_orig.step() optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_param.point_to_file(f.name, 0)
assert torch.equal(ssd_param.to_tensor(), param) assert torch.equal(ssd_param.to_tensor(), param)
...@@ -211,8 +236,186 @@ def test_ssd_flat_parameter_basic(): ...@@ -211,8 +236,186 @@ def test_ssd_flat_parameter_basic():
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32)) refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], False) ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
def test_ssd_flat_parameter_view_modify():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0) ssd_flat_param.set_file_params(f.name, 0)
ssd_flat_param.flush_on_dirty = False
param_views = list(ssd_flat_param.get_param_views())
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
ssd_flat_param.to_file()
assert ssd_flat_param.storage_state == so.StorageState.ON_DISK
assert param_views[0].tensor is None
param_views[0] += 0.1
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
def test_ssd_flat_parameter_view_bwd():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
refa_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refb_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refc_param = (
torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.data = cuda_copy
one = torch.ones((1), requires_grad=True, device=ssd_flat_param.device)
y1 = ssd_flat_param.views[0] + one
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_view_bwd_parameterization():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
layer1 = torch.nn.Linear(32, 4, bias=False)
layer2 = torch.nn.Linear(32, 4, bias=False)
layer3 = torch.nn.Linear(128, 1, bias=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0
)
torch.nn.utils.parametrize.register_parametrization(
layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0)
)
torch.nn.utils.parametrize.register_parametrization(
layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1)
)
torch.nn.utils.parametrize.register_parametrization(
layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2)
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.to_file(release_tensor_after_write=False)
ssd_flat_param.data = cuda_copy
one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device)
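# With the parametrization registered, layer1.weight resolves to ssd_flat_param.views[0],
# so this forward/backward should drive gradients back into the flat parameter.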
y1 = layer1.forward(one)
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_direct_to_file():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
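# direct_to_file=True writes the flattened values straight to f.name at construction time,
# avoiding a separate in-memory flat copy.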
param_views = list(ssd_flat_param.get_param_views())
...@@ -224,3 +427,8 @@ def test_ssd_flat_parameter_basic():
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
...@@ -16,17 +16,17 @@ import torch
from torch import nn
import torch.distributed
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: ssd_offload requires PyTorch >= 1.12.0 (currently only available in nightly builds),
# so skip this whole module if the import fails.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
print(f"torch version {torch_version()}")
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod
...@@ -137,8 +137,6 @@ def rename_test(testcase_func, param_num, param):
class TestSsdMemory(DistributedTest):
def test_memory_benchmark(self):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_memory_benchmark, config={})
spawn_and_init(test_fn)
...@@ -218,8 +216,6 @@ class TimeKeeper:
class TestModuleProperties(DistributedTest):
@parameterized.expand(CONFIG, name_func=rename_test)
def test_named_parameters(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_named_params, config=config)
spawn_and_init(test_fn)
...@@ -264,23 +260,17 @@ class TestModuleProperties(DistributedTest):
class TestSsdLoading(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_eval(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG, name_func=rename_test)
def test_transformer_parameterized(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_train_flatten_params_wrapper(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
spawn_and_init(test_fn)
...@@ -288,6 +278,8 @@ class TestSsdLoading(DistributedTest):
@classmethod
def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config):
SIZE = 16 * 16
LR = 0.01
MOMENTUM = 0.1
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
with tempfile.TemporaryDirectory() as current_tempdir:
...@@ -305,7 +297,7 @@ class TestSsdLoading(DistributedTest):
model = FullyShardedDataParallel(model, **config)
model_device = torch.device("cuda")
model.train()
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint_file = tempfile.NamedTemporaryFile()
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
...@@ -322,9 +314,13 @@ class TestSsdLoading(DistributedTest):
input = model.get_input(torch.device("cuda"))
output = model(*input)
pre_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} pre_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
if i == 0:
...@@ -332,18 +328,23 @@ class TestSsdLoading(DistributedTest):
# so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name)
torch.save({"model": model.state_dict()}, checkpoint_file.name)
# reset momentum just after checkpoint save
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint = torch.load(checkpoint_file.name)
model.load_state_dict(checkpoint["model"])
# reset momentum just after checkpoint load
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
# do more iterations after loading checkpoint
for i in range(ITERATIONS - 1):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
post_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} post_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32