Unverified Commit 51b53ddb authored by Paul Johnson, committed by GitHub

[FSDP] ssd_offload fixing backward path (grad_fn) for SsdFlatParameter and...

[FSDP] ssd_offload fixing backward path (grad_fn) for SsdFlatParameter and SsdFlatParameterView (#974)

* [FSDP] fixing backward path for SsdFlatParameter and SsdFlatParameterView when overriding .data

* Get ssd_offload unit tests passing

* [FSDP] get all test_fsdp_offload tests passing w/ ssd_offload on

* Update changelog
parent 58ccb166
...@@ -12,7 +12,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ...@@ -12,7 +12,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes), - FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes),
verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict() verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict()
and model.load_state_dict(...) and thus successfully checkpoint and recover parameters and model.load_state_dict(...) and thus successfully checkpoint and recover parameters
stored as SsdFlatParameters. stored as SsdFlatParameters. [#964]
- FSDP Ssd Offload: [#974]
* Add a __setattr__ method to SsdTensorHandle/SsdFlatParameter to capture when .data is overridden,
and perform the necessary checks/updates before letting the tensor metadata be updated
(see the first sketch after this list).
* There was a bug in PyTorch core where re-assigning a tensor subclass's .data disabled the
__torch_dispatch__ mechanism and effectively reverted it to a normal tensor. Since the
ssd_offload feature uses __torch_dispatch__ extensively and FSDP overrides .data, ssd_offload
can now only be used with PyTorch 1.12.0 or later (currently the pytorch-nightly release).
On older versions the pytest tests are skipped, and importing ssd_offload or using FSDP with
ssd_offload enabled raises an ImportError.
* Split the ON_CPU storage_state into ON_CPU_CLEAN and ON_CPU_DIRTY. ON_CPU_CLEAN
indicates that the .tensor value and the value stored on disk are identical, so no
write is needed when calling .to_file(). ON_CPU_DIRTY indicates that .tensor does not
match the value stored on disk, so a write is necessary. This tracking is disabled in FSDP,
as FSDP does not currently flush parameters to disk as soon as the optimizer modifies them
(a minimal bookkeeping sketch appears after this list).
* Fix detection in SsdTensorHandle.__torch_dispatch__ of when the handle is used by an in-place
operator or as an output, so that mark_dirty() is called in those cases.
* Fix grad_fn on SsdFlatParameterViews so that each ssd_flat_param.views[i].grad_fn points
to a split and view of the parent ssd_flat_param; when .backward() is called,
ssd_flat_param.grad is properly updated in the backward pass (see the split/view example
after this list).
* Add unit tests to verify that grad_fn and backward hooks are called appropriately on
SsdFlatParameters/SsdFlatParameterViews.
* Implement the SsdFlatParameter.from_tensors() direct_to_file option. This allows creating
an SsdFlatParameter and writing it directly to disk, rather than materializing a new flat
tensor in memory first. This avoids instantiating another copy of the tensors and can save
memory when creating very large SsdFlatParameters (see the streaming-write sketch after this list).
* Add SsdFlatParameterViewProperty for overriding a parameter in an nn.Module, similar to
PyTorch core's parametrization code path. Replacing e.g. Layer.weight with a property whose
getter returns ssd_flat_param.views[view_id] lets SsdFlatParameterViews stay current as
SsdFlatParameter.data is overridden (a property-override sketch follows this list).
* Add an example, SsdFlatParameterViewParameterization, showing how the existing parametrization
mechanism could be used instead of SsdFlatParameterViewProperty. This approach is less memory
efficient, however, because parametrization always keeps a copy of the original parameter
internally.
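
As a rough illustration of the .data-capture idea above, here is a minimal sketch; the class name, the hook method, and the check it performs are illustrative assumptions, not the actual fairscale implementation:

```python
import torch
import torch.nn as nn

class DataCaptureParameter(nn.Parameter):
    """Minimal sketch: intercept assignments to `.data` so bookkeeping can run
    before the tensor metadata is swapped out (e.g. when FSDP replaces the data
    with a shard)."""

    def __setattr__(self, name, value):
        if name == "data":
            # Hypothetical hook: validate the new value and mark any on-disk
            # copy as stale before the swap takes effect.
            self._before_data_override(value)
        super().__setattr__(name, value)

    def _before_data_override(self, new_value):
        assert isinstance(new_value, torch.Tensor)

p = DataCaptureParameter(torch.randn(4))
p.data = torch.zeros(2)  # triggers _before_data_override, then updates the metadata
```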
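
A similarly small, self-contained sketch of the clean/dirty bookkeeping described above; the state names mirror the changelog, but the handle class is a simplified stand-in for SsdTensorHandle, not the real implementation:

```python
import enum

import torch

class StorageState(enum.Enum):
    ON_DISK = enum.auto()
    ON_CPU_CLEAN = enum.auto()
    ON_CPU_DIRTY = enum.auto()

class TinyHandle:
    """Write to disk only when the in-memory tensor has diverged from the file copy."""

    def __init__(self, tensor: torch.Tensor, filename: str):
        self.tensor = tensor
        self.filename = filename
        self.storage_state = StorageState.ON_CPU_DIRTY  # nothing persisted yet

    def mark_dirty(self) -> None:
        # Called when the handle is modified in place or used as an output.
        self.storage_state = StorageState.ON_CPU_DIRTY

    def to_file(self) -> None:
        if self.storage_state is StorageState.ON_CPU_DIRTY:
            torch.save(self.tensor, self.filename)  # stand-in for the raw file write
        self.storage_state = StorageState.ON_CPU_CLEAN
```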
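
The split/view wiring for views can be demonstrated with plain tensors; this is stock autograd behavior, shown only to make the grad_fn fix concrete:

```python
import torch

flat = torch.randn(12, requires_grad=True)            # stands in for the flat parameter
chunks = flat.split([6, 6])                            # each chunk's grad_fn points back to flat
views = [chunks[0].view(3, 2), chunks[1].view(2, 3)]   # reshape to the original param shapes

(views[0] + 1).sum().backward()                        # backward through one view...
assert flat.grad is not None                           # ...populates the parent's .grad
assert flat.grad[:6].eq(1).all() and flat.grad[6:].eq(0).all()
```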
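
A sketch of the direct_to_file idea: stream each source tensor into the backing file instead of first building one large flat tensor in memory. The helper name and the raw-bytes layout are assumptions for illustration only:

```python
import tempfile

import torch

def write_tensors_directly(tensors, filename: str) -> None:
    """Append each tensor's raw bytes to the file, avoiding a concatenated in-memory copy."""
    with open(filename, "wb") as f:
        for t in tensors:
            f.write(t.detach().contiguous().cpu().numpy().tobytes())

with tempfile.NamedTemporaryFile() as f:
    write_tensors_directly([torch.rand(32, 4), torch.rand(128)], f.name)
```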
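
Finally, a sketch of the property-override approach: a module attribute is replaced by a class-level property whose getter re-reads the current view, so the module always sees a view that matches the flat parameter's current .data. The helper and dummy class below are illustrative only; the real SsdFlatParameterViewProperty lives in fairscale.experimental.nn.ssd_offload:

```python
import torch
import torch.nn as nn

class _DummyFlat:
    """Stand-in exposing a .views list, just to make the sketch runnable."""
    def __init__(self, views):
        self.views = views

def install_view_property(module: nn.Module, name: str, flat_param, view_id: int) -> None:
    """Replace `module.<name>` with a property returning flat_param.views[view_id]."""
    if name in module._parameters:
        del module._parameters[name]  # deregister the original nn.Parameter
    prop = property(lambda self: flat_param.views[view_id])
    # Properties must live on the class, so give this instance its own subclass,
    # much like torch.nn.utils.parametrize does internally.
    module.__class__ = type(f"Patched{type(module).__name__}", (type(module),), {name: prop})

layer = nn.Linear(32, 4, bias=False)
flat = _DummyFlat([torch.randn(4, 32)])
install_view_property(layer, "weight", flat, 0)
assert layer.weight is flat.views[0]  # the getter re-reads the current view every time
```
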
### Fixed ### Fixed
......
...@@ -66,7 +66,6 @@ else: ...@@ -66,7 +66,6 @@ else:
try: try:
import fairscale.experimental.nn.ssd_offload as ssd_offload import fairscale.experimental.nn.ssd_offload as ssd_offload
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
import_ssd_offload = True import_ssd_offload = True
except ImportError: except ImportError:
...@@ -399,6 +398,10 @@ class FullyShardedDataParallel(nn.Module): ...@@ -399,6 +398,10 @@ class FullyShardedDataParallel(nn.Module):
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk. # Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
if self.ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor( self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
self.world_size self.world_size
...@@ -761,16 +764,14 @@ class FullyShardedDataParallel(nn.Module): ...@@ -761,16 +764,14 @@ class FullyShardedDataParallel(nn.Module):
p._is_sharded = True p._is_sharded = True
# Replace p.data with the relevant shard. # Replace p.data with the relevant shard.
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
if self.ssd_offload: if self.ssd_offload:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
sharded_tensor, num_padded = self._get_shard(p.data)
p.point_to_resized_tensor(sharded_tensor)
self.numel_padded_per_param.append(num_padded)
p.to_file() p.to_file()
else: else:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data) free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params) assert len(self.numel_padded_per_param) == len(self.params)
...@@ -987,7 +988,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -987,7 +988,7 @@ class FullyShardedDataParallel(nn.Module):
def _move_params_to_memory(self) -> None: def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU.""" """Move params from disk to CPU."""
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor() p.to_tensor()
def _load_state_dict( def _load_state_dict(
...@@ -1138,14 +1139,19 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1138,14 +1139,19 @@ class FullyShardedDataParallel(nn.Module):
# Copy any changes made to the full params back into # Copy any changes made to the full params back into
# the corresponding local shards. # the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor) local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard)) if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu())
else:
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free: if safe_to_free:
free_storage_(full_tensor) free_storage_(full_tensor)
self.has_full_params = False self.has_full_params = False
if self.ssd_offload: if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage. # Store tensors in the SSD buffer and free param storage.
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_file() p.to_file()
else: else:
self._use_fp32_param_shard() self._use_fp32_param_shard()
...@@ -1281,10 +1287,11 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1281,10 +1287,11 @@ class FullyShardedDataParallel(nn.Module):
# shard in pinned memory so that we can do a non-blocking transfer. # shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation. # This is only needed during training and not evaluation.
if self.ssd_offload: if self.ssd_offload:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
# Gradients also need to be offloaded to SSD otherwise it can result in # Gradients also need to be offloaded to SSD otherwise it can result in
# OOMs when the memory requirements of a model are larger than host memory. # OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu")) p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.allow_unsafe_changes = True
p._cpu_grad.set_file_params(p.filename + "_grad", 0) p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file() p._cpu_grad.to_file()
else: else:
...@@ -1441,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1441,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module):
def _free_ssd_offload(self) -> None: def _free_ssd_offload(self) -> None:
if self.ssd_offload: if self.ssd_offload:
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_file(permit_when_tensor_none=True) p.to_file(permit_when_tensor_none=True)
def _register_pre_backward_hooks(self, outputs: Any) -> Any: def _register_pre_backward_hooks(self, outputs: Any) -> Any:
...@@ -1897,8 +1904,10 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1897,8 +1904,10 @@ class FullyShardedDataParallel(nn.Module):
if self.ssd_offload: if self.ssd_offload:
for p in self.params: for p in self.params:
assert isinstance(p, SsdFlatParameter) assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor() if not p.is_available():
self._ssd_offload_reset_param_device(p)
p.to_tensor()
self.has_full_params = False self.has_full_params = False
...@@ -2175,13 +2184,25 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2175,13 +2184,25 @@ class FullyShardedDataParallel(nn.Module):
return consolidated_weights return consolidated_weights
@torch.no_grad()
def _ssd_offload_reset_param_device(self, param: Parameter) -> None:
assert isinstance(param, ssd_offload.SsdParameter)
if param.device != torch.device("cpu"):
param.data = param._fp32_shard
param.tensor = None
@torch.no_grad() @torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None: def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params.""" """Use FP32 shard for a list of params."""
if params is None: if params is None:
params = self.params params = self.params
for p in params: for p in params:
p.data = p._fp32_shard if import_ssd_offload and self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.to_tensor()
else:
p.data = p._fp32_shard
@torch.no_grad() @torch.no_grad()
def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None: def _cast_fp32_param_shards_to_fp16(self, params: Optional[List[Parameter]] = None) -> None:
...@@ -2192,11 +2213,14 @@ class FullyShardedDataParallel(nn.Module): ...@@ -2192,11 +2213,14 @@ class FullyShardedDataParallel(nn.Module):
for p in params: for p in params:
assert p._fp16_shard is not None assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size()) alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_( if self.ssd_offload:
# If move_params_to_cpu is True, this will be non-blocking p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True))
# because _fp32_shard is pinned, otherwise it's a no-op. else:
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True) p._fp16_shard.copy_(
) # If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"]) torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
......
...@@ -31,7 +31,19 @@ import torch ...@@ -31,7 +31,19 @@ import torch
from torch import Tensor from torch import Tensor
import torch.nn as nn import torch.nn as nn
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter try:
from fairscale.experimental.nn.ssd_offload import (
SsdFlatParameter,
SsdFlatParameterView,
SsdFlatParameterViewProperty,
_register_property,
)
import_ssd_offload = True
except ImportError:
import_ssd_offload = False
pass
from fairscale.utils.state_dict import replace_by_prefix_ from fairscale.utils.state_dict import replace_by_prefix_
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -116,7 +128,6 @@ class FlatParameter(nn.Parameter): ...@@ -116,7 +128,6 @@ class FlatParameter(nn.Parameter):
# Static types. # Static types.
FlatTypes = Union[FlatParameter, SsdFlatParameter]
ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]] ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]]
...@@ -159,6 +170,11 @@ class FlattenParamsWrapper(nn.Module): ...@@ -159,6 +170,11 @@ class FlattenParamsWrapper(nn.Module):
ssd_directory: str = "", ssd_directory: str = "",
): ):
super().__init__() super().__init__()
if ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.ssd_offload = ssd_offload
self._fpw_module = module self._fpw_module = module
self.is_flattened = False self.is_flattened = False
...@@ -205,7 +221,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -205,7 +221,7 @@ class FlattenParamsWrapper(nn.Module):
# support. # support.
raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}") raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}")
self.flat_params: List[FlatTypes] = [] self.flat_params: List[nn.Parameter] = []
# Prepare flat param names. # Prepare flat param names.
if flat_param_names is None: if flat_param_names is None:
...@@ -215,7 +231,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -215,7 +231,7 @@ class FlattenParamsWrapper(nn.Module):
if len(flat_param_names) != len(set(flat_param_names)): if len(flat_param_names) != len(set(flat_param_names)):
raise ValueError("Each flat param must be given a unique name") raise ValueError("Each flat param must be given a unique name")
self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names] self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names]
flat_param: Optional[FlatTypes] = None flat_param: Optional[nn.Parameter] = None
# Init all flat_params. # Init all flat_params.
for new_p_set in self._param_sets: for new_p_set in self._param_sets:
...@@ -224,6 +240,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -224,6 +240,7 @@ class FlattenParamsWrapper(nn.Module):
assert ssd_directory != "" assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param") (handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter.from_tensors(tensors=params) flat_param = SsdFlatParameter.from_tensors(tensors=params)
flat_param.allow_unsafe_changes = True
flat_param.set_file_params(fname, 0) flat_param.set_file_params(fname, 0)
else: else:
flat_param = FlatParameter(params, params[0].requires_grad) flat_param = FlatParameter(params, params[0].requires_grad)
...@@ -309,13 +326,13 @@ class FlattenParamsWrapper(nn.Module): ...@@ -309,13 +326,13 @@ class FlattenParamsWrapper(nn.Module):
@property @property
def _param_infos(self) -> Iterator[Tuple[str, nn.Module, str]]: def _param_infos(self) -> Iterator[Tuple[str, nn.Module, str]]:
return chain(*[p._param_infos for p in self.flat_params]) return chain(*[p._param_infos for p in self.flat_params]) # type: ignore
@property @property
def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]: def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]:
return chain(*[p._shared_param_infos for p in self.flat_params]) return chain(*[p._shared_param_infos for p in self.flat_params]) # type: ignore
def _flatten_params(self, flat_params: List[FlatTypes]) -> None: def _flatten_params(self, flat_params: List[nn.Parameter]) -> None:
"""Flatten the managed parameters and replaced the original """Flatten the managed parameters and replaced the original
attributes with views to the flat params. attributes with views to the flat params.
""" """
...@@ -331,6 +348,7 @@ class FlattenParamsWrapper(nn.Module): ...@@ -331,6 +348,7 @@ class FlattenParamsWrapper(nn.Module):
# deregister the names as parameters # deregister the names as parameters
for _, m, n in self._param_infos: for _, m, n in self._param_infos:
delattr(m, n) delattr(m, n)
for _, _, m, n, _, _ in self._shared_param_infos: for _, _, m, n, _, _ in self._shared_param_infos:
delattr(m, n) delattr(m, n)
...@@ -372,8 +390,13 @@ class FlattenParamsWrapper(nn.Module): ...@@ -372,8 +390,13 @@ class FlattenParamsWrapper(nn.Module):
ps = self.get_param_views() ps = self.get_param_views()
param_views = [] param_views = []
for (_, m, n), p in zip(self._param_infos, ps): for (_, m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr if self.ssd_offload:
param_views.append(p) assert isinstance(p, SsdFlatParameterView)
_register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id))
else:
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
# Save param views for easy access if anyone still wants to access # Save param views for easy access if anyone still wants to access
# parameters of the module. # parameters of the module.
...@@ -498,13 +521,13 @@ class FlattenParamsWrapper(nn.Module): ...@@ -498,13 +521,13 @@ class FlattenParamsWrapper(nn.Module):
gens = [] gens = []
for p, data in zip(params, external_data_list): for p, data in zip(params, external_data_list):
gens.append(p.get_param_views(data)) gens.append(p.get_param_views(data)) # type: ignore
return chain(*gens) return chain(*gens)
def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]: def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return metadata for a flat param given its index in the flat_params list.""" """Return metadata for a flat param given its index in the flat_params list."""
return self.flat_params[flat_param_idx].metadata() return self.flat_params[flat_param_idx].metadata() # type: ignore
def _post_state_dict_hook( def _post_state_dict_hook(
......
...@@ -8,6 +8,7 @@ Testing SsdFlatParameter and SsdTensorHandle modules. ...@@ -8,6 +8,7 @@ Testing SsdFlatParameter and SsdTensorHandle modules.
""" """
import filecmp import filecmp
import functools
import os import os
import tempfile import tempfile
...@@ -15,11 +16,12 @@ import numpy as np ...@@ -15,11 +16,12 @@ import numpy as np
import pytest import pytest
import torch import torch
import fairscale.experimental.nn.ssd_offload as so try:
from fairscale.utils import torch_version import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release. # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0") pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
def _init(): def _init():
...@@ -77,10 +79,32 @@ def test_ssd_handle_dispatch_bwd(): ...@@ -77,10 +79,32 @@ def test_ssd_handle_dispatch_bwd():
assert torch.equal(ssd_handle.grad, orig_copy.grad) assert torch.equal(ssd_handle.grad, orig_copy.grad)
def test_ssd_handle_train_simple(): def test_ssd_handle_dispatch_bwd_hook():
if torch_version() >= (1, 12, 0): _init()
pytest.skip("to be fixed")
def post_backward_hook(name, grad):
print(f"BACKWARD HOOK for tensor {name} CALLED")
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
one = torch.ones((1), requires_grad=True).cuda()
orig_copy = ssd_handle.data
cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
ssd_handle.data = cuda_copy
ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle"))
one.register_hook(functools.partial(post_backward_hook, "one"))
y1 = ssd_handle + one
y1.sum().backward()
def test_ssd_handle_train_simple():
_init() _init()
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
...@@ -92,6 +116,7 @@ def test_ssd_handle_train_simple(): ...@@ -92,6 +116,7 @@ def test_ssd_handle_train_simple():
orig_copy.requires_grad = True orig_copy.requires_grad = True
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor) ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.flush_on_dirty = False
ssd_handle.set_file_params(f.name, 0) ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True) ssd_handle.to_file(release_tensor_after_write=True)
...@@ -102,15 +127,15 @@ def test_ssd_handle_train_simple(): ...@@ -102,15 +127,15 @@ def test_ssd_handle_train_simple():
y1 = ssd_handle + 1 y1 = ssd_handle + 1
optimizer_ssd.zero_grad() optimizer_ssd.zero_grad()
y1.sum().backward() y1.sum().backward()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step() optimizer_ssd.step()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = orig_copy + 1 y2 = orig_copy + 1
optimizer_orig.zero_grad() optimizer_orig.zero_grad()
y2.sum().backward() y2.sum().backward()
optimizer_orig.step() optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_handle.point_to_file(f.name, 0)
assert torch.equal(ssd_handle.to_tensor(), orig_copy) assert torch.equal(ssd_handle.to_tensor(), orig_copy)
...@@ -169,9 +194,6 @@ def test_torch_save_load_ssd_flat_param_on_mem(): ...@@ -169,9 +194,6 @@ def test_torch_save_load_ssd_flat_param_on_mem():
def test_ssd_param_train_simple(): def test_ssd_param_train_simple():
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
_init() _init()
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4)) orig_tensor = torch.randn((4, 4))
...@@ -183,6 +205,7 @@ def test_ssd_param_train_simple(): ...@@ -183,6 +205,7 @@ def test_ssd_param_train_simple():
ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype) ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
ssd_param.point_to_tensor(orig_copy) ssd_param.point_to_tensor(orig_copy)
ssd_param.flush_on_dirty = False
ssd_param.set_file_params(f.name, 0) ssd_param.set_file_params(f.name, 0)
ssd_param.to_file(release_tensor_after_write=True) ssd_param.to_file(release_tensor_after_write=True)
...@@ -193,15 +216,17 @@ def test_ssd_param_train_simple(): ...@@ -193,15 +216,17 @@ def test_ssd_param_train_simple():
y1 = ssd_param + 1 y1 = ssd_param + 1
optimizer_ssd.zero_grad() optimizer_ssd.zero_grad()
y1.sum().backward() y1.sum().backward()
# Test to see if Dirty is being calculated correctly when optimizer modifies
# ssd_param
assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step() optimizer_ssd.step()
assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = param + 1 y2 = param + 1
optimizer_orig.zero_grad() optimizer_orig.zero_grad()
y2.sum().backward() y2.sum().backward()
optimizer_orig.step() optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_param.point_to_file(f.name, 0)
assert torch.equal(ssd_param.to_tensor(), param) assert torch.equal(ssd_param.to_tensor(), param)
...@@ -211,8 +236,186 @@ def test_ssd_flat_parameter_basic(): ...@@ -211,8 +236,186 @@ def test_ssd_flat_parameter_basic():
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32)) refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32)) refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], False) ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
def test_ssd_flat_parameter_view_modify():
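# In-place modification of a view should flip the parent flat parameter from ON_DISK back to ON_CPU_DIRTY.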
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0) ssd_flat_param.set_file_params(f.name, 0)
ssd_flat_param.flush_on_dirty = False
param_views = list(ssd_flat_param.get_param_views())
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
ssd_flat_param.to_file()
assert ssd_flat_param.storage_state == so.StorageState.ON_DISK
assert param_views[0].tensor is None
param_views[0] += 0.1
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
def test_ssd_flat_parameter_view_bwd():
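# Backward through a view should fire hooks on the view, the parent flat parameter, and its GradAccumulation node, even after .data has been re-assigned to a CUDA copy.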
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
refa_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refb_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refc_param = (
torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.data = cuda_copy
one = torch.ones((1), requires_grad=True, device=ssd_flat_param.device)
y1 = ssd_flat_param.views[0] + one
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_view_bwd_parameterization():
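# Same backward-hook checks as above, but the views are wired into the layers via torch.nn.utils.parametrize instead of SsdFlatParameterViewProperty.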
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
layer1 = torch.nn.Linear(32, 4, bias=False)
layer2 = torch.nn.Linear(32, 4, bias=False)
layer3 = torch.nn.Linear(128, 1, bias=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0
)
torch.nn.utils.parametrize.register_parametrization(
layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0)
)
torch.nn.utils.parametrize.register_parametrization(
layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1)
)
torch.nn.utils.parametrize.register_parametrization(
layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2)
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.to_file(release_tensor_after_write=False)
ssd_flat_param.data = cuda_copy
one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device)
y1 = layer1.forward(one)
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_direct_to_file():
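# Exercise the direct_to_file=True path: the flat parameter is written straight to f.name at construction time.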
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
param_views = list(ssd_flat_param.get_param_views()) param_views = list(ssd_flat_param.get_param_views())
...@@ -224,3 +427,8 @@ def test_ssd_flat_parameter_basic(): ...@@ -224,3 +427,8 @@ def test_ssd_flat_parameter_basic():
assert torch.equal(refb_param, param_views[1]) assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2]) assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file() ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
...@@ -16,17 +16,17 @@ import torch ...@@ -16,17 +16,17 @@ import torch
from torch import nn from torch import nn
import torch.distributed import torch.distributed
import fairscale.experimental.nn.ssd_offload as so try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
print(f"torch version {torch_version()}")
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod # All helper functions called by spawn must be either @classmethod, @staticmethod
...@@ -137,8 +137,6 @@ def rename_test(testcase_func, param_num, param): ...@@ -137,8 +137,6 @@ def rename_test(testcase_func, param_num, param):
class TestSsdMemory(DistributedTest): class TestSsdMemory(DistributedTest):
def test_memory_benchmark(self): def test_memory_benchmark(self):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_memory_benchmark, config={}) test_fn = functools.partial(self._test_memory_benchmark, config={})
spawn_and_init(test_fn) spawn_and_init(test_fn)
...@@ -218,8 +216,6 @@ class TimeKeeper: ...@@ -218,8 +216,6 @@ class TimeKeeper:
class TestModuleProperties(DistributedTest): class TestModuleProperties(DistributedTest):
@parameterized.expand(CONFIG, name_func=rename_test) @parameterized.expand(CONFIG, name_func=rename_test)
def test_named_parameters(self, config): def test_named_parameters(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_named_params, config=config) test_fn = functools.partial(self._test_named_params, config=config)
spawn_and_init(test_fn) spawn_and_init(test_fn)
...@@ -264,23 +260,17 @@ class TestModuleProperties(DistributedTest): ...@@ -264,23 +260,17 @@ class TestModuleProperties(DistributedTest):
class TestSsdLoading(DistributedTest): class TestSsdLoading(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_eval(self, config): def test_ssd_offloading_eval(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offload_eval, config=config) test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
spawn_and_init(test_fn) spawn_and_init(test_fn)
@parameterized.expand(CONFIG, name_func=rename_test) @parameterized.expand(CONFIG, name_func=rename_test)
def test_transformer_parameterized(self, config): def test_transformer_parameterized(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config)) spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test) @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_train_flatten_params_wrapper(self, config): def test_ssd_offloading_train_flatten_params_wrapper(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config) test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
spawn_and_init(test_fn) spawn_and_init(test_fn)
...@@ -288,6 +278,8 @@ class TestSsdLoading(DistributedTest): ...@@ -288,6 +278,8 @@ class TestSsdLoading(DistributedTest):
@classmethod @classmethod
def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config): def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config):
SIZE = 16 * 16 SIZE = 16 * 16
LR = 0.01
MOMENTUM = 0.1
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4) model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
with tempfile.TemporaryDirectory() as current_tempdir: with tempfile.TemporaryDirectory() as current_tempdir:
...@@ -305,7 +297,7 @@ class TestSsdLoading(DistributedTest): ...@@ -305,7 +297,7 @@ class TestSsdLoading(DistributedTest):
model = FullyShardedDataParallel(model, **config) model = FullyShardedDataParallel(model, **config)
model_device = torch.device("cuda") model_device = torch.device("cuda")
model.train() model.train()
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint_file = tempfile.NamedTemporaryFile() checkpoint_file = tempfile.NamedTemporaryFile()
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir") checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
...@@ -322,9 +314,13 @@ class TestSsdLoading(DistributedTest): ...@@ -322,9 +314,13 @@ class TestSsdLoading(DistributedTest):
input = model.get_input(torch.device("cuda")) input = model.get_input(torch.device("cuda"))
output = model(*input) output = model(*input)
pre_checkpoint_last_output = output pre_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} pre_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device) loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32 assert loss.dtype == torch.float32
model.module.run_backward(loss) model.module.run_backward(loss)
optim.step() optim.step()
if i == 0: if i == 0:
...@@ -332,18 +328,23 @@ class TestSsdLoading(DistributedTest): ...@@ -332,18 +328,23 @@ class TestSsdLoading(DistributedTest):
# so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name) # so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name)
torch.save({"model": model.state_dict()}, checkpoint_file.name) torch.save({"model": model.state_dict()}, checkpoint_file.name)
# reset momentum just after checkpoint save # reset momentum just after checkpoint save
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint = torch.load(checkpoint_file.name) checkpoint = torch.load(checkpoint_file.name)
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"])
# reset momentum just after checkpoint load # reset momentum just after checkpoint load
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
# do more iterations after loading checkpoint # do more iterations after loading checkpoint
for i in range(ITERATIONS - 1): for i in range(ITERATIONS - 1):
optim.zero_grad() optim.zero_grad()
input = model.get_input(torch.device("cuda")) input = model.get_input(torch.device("cuda"))
output = model(*input) output = model(*input)
post_checkpoint_last_output = output post_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} post_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device) loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32 assert loss.dtype == torch.float32
......