Unverified commit 51b53ddb authored by Paul Johnson, committed by GitHub


[FSDP] ssd_offload fixing backward path (grad_fn) for SsdFlatParameter and SsdFlatParameterView (#974)

* [FSDP] fixing backward path for SsdFlatParameter and SsdFlatParameterView when overriding .data

* Get ssd_offload unit tests passing

* [FSDP] get all test_fsdp_offload tests passing w/ ssd_offload on

* Update changelog
parent 58ccb166
......@@ -12,7 +12,40 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- FSDP: Add pickle/unpickle support for SsdTensorHandle (and derived classes),
verified that FSDP models w/ ssd_offload enabled can correctly call model.state_dict()
and model.load_state_dict(...) and thus successfully checkpoint and recover parameters
stored as SsdFlatParameters.
stored as SsdFlatParameters. [#964]
- FSDP Ssd Offload: [#974]
* Add a __setattr__ function to SsdTensorHandle/SsdFlatParameter to capture when .data is overridden
and perform the necessary checks/updates before letting the tensor metadata be updated
(see the __setattr__ sketch below).
* A bug in pytorch core caused re-assigning a tensor subclass's .data to
disable the __torch_dispatch__ mechanism and effectively revert the tensor back to a normal
tensor. Since the ssd_offload feature uses __torch_dispatch__ extensively and FSDP overrides
.data, ssd_offload can now only be used with pytorch version 1.12.0 or later
(currently the pytorch-nightly release). On older versions the pytest tests are disabled, and trying
to import ssd_offload or use FSDP with ssd_offload enabled raises an ImportError.
* Split the storage_state ON_CPU into ON_CPU_CLEAN and ON_CPU_DIRTY. ON_CPU_CLEAN
indicates that the .tensor value and the value stored on disk are identical, so no
write is needed when calling .to_file(). ON_CPU_DIRTY indicates that .tensor does not
match the value stored on disk and a write is necessary. This is disabled in FSDP, as FSDP
does not currently flush variables to disk as soon as the optimizer modifies values
(see the storage-state sketch below).
* Fix detection in SsdTensorHandle's __torch_dispatch__ of cases where the handle is used in an
in-place operator or as an output, so that mark_dirty() is called in those cases.
* Fix grad_fn on SsdFlatParameterViews so that each ssd_flat_param.views[i].grad_fn points
to a split and view of the ssd_flat_param; when backward() is called, ssd_flat_param.grad
is then properly updated in the backward pass (see the split/view sketch below).
* Add unit tests to verify grad_fn and backward hooks are called appropriately on
SsdFlatParameters/SsdFlatParameterViews.
* Implement the SsdFlatParameter.from_tensors() direct_to_file option. This allows creating
an SsdFlatParameter and writing it directly to disk, rather than creating a new tensor
first. This prevents another copy of the tensors from being instantiated and can save memory when
creating very large SsdFlatParameters.
* Add SsdFlatParameterViewProperty for overriding a parameter in an nn.Module,
similar to pytorch core's parametrization code path. This allows SsdFlatParameterViews
to stay current as SsdFlatParameter.data is overridden, by replacing Layer.weight with a
property whose getter returns ssd_flat_param.views[view_id] (see the property sketch below).
* Add an example, SsdFlatParameterViewParameterization, showing how the existing parametrization
mechanism could be used instead of SsdFlatParameterViewProperty. However, this method results
in memory inefficiencies because parametrization always keeps a copy of the original
parameter internally.
### Fixed
......
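The __setattr__ bullet above describes intercepting assignments to .data on a tensor subclass. A minimal sketch of that idea, assuming an illustrative TrackedTensor class and a made-up _on_data_override hook (this is not the fairscale implementation):

```python
import torch

class TrackedTensor(torch.Tensor):
    """Illustrative tensor subclass that notices when .data is reassigned."""

    def __setattr__(self, name, value):
        if name == "data":
            # Run bookkeeping before the underlying storage is swapped, e.g.
            # validating shapes/dtypes or marking file-backed state as stale.
            self._on_data_override(value)
        super().__setattr__(name, value)

    def _on_data_override(self, new_value):
        # Hypothetical hook; a real handle would update its storage-state here.
        print(f".data overridden with a tensor of shape {tuple(new_value.shape)}")

# _make_subclass is the same private helper nn.Parameter uses to wrap a tensor.
t = torch.Tensor._make_subclass(TrackedTensor, torch.randn(4, 4))
t.data = torch.zeros(4, 4)  # triggers _on_data_override before the swap
```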
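The ON_CPU_CLEAN / ON_CPU_DIRTY split amounts to tracking whether the in-memory tensor still matches the bytes on disk. A minimal sketch of that bookkeeping, using assumed names that mirror the changelog rather than the real ssd_offload enum:

```python
import enum

class StorageState(enum.Enum):
    ON_DISK = enum.auto()
    ON_CPU_CLEAN = enum.auto()  # in-memory copy matches the file; to_file() may skip the write
    ON_CPU_DIRTY = enum.auto()  # in-memory copy has unflushed changes; to_file() must write

def needs_write(state: StorageState) -> bool:
    # Only dirty tensors require I/O when flushing back to disk.
    return state is StorageState.ON_CPU_DIRTY

def after_inplace_op(state: StorageState) -> StorageState:
    # An in-place op (or an optimizer step) invalidates the on-disk copy,
    # which is what mark_dirty() records on the real handle.
    return StorageState.ON_CPU_DIRTY
```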
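The grad_fn fix for views can be pictured with plain torch ops: when each per-layer view is produced by split() and view() on the flat tensor, the views stay attached to the flat tensor's autograd graph, so backward() accumulates into the flat tensor's .grad. This is a generic illustration, not the SsdFlatParameterView code itself:

```python
import torch

# Stand-in for a flat parameter holding three original parameters.
flat = torch.randn(32 + 32 + 64, requires_grad=True)
sizes, shapes = [32, 32, 64], [(8, 4), (4, 8), (64,)]

# Building views via split()+view() keeps their grad_fn chained back to `flat`
# instead of producing detached copies.
views = [chunk.view(shape) for chunk, shape in zip(flat.split(sizes), shapes)]

loss = (views[0] * 2).sum() + views[2].sum()
loss.backward()

assert views[0].grad_fn is not None  # views participate in the autograd graph
assert flat.grad is not None         # the gradient lands on the flat tensor
```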
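SsdFlatParameterViewProperty is described as replacing Layer.weight with a property whose getter returns the flat parameter's current view. The sketch below shows that pattern with hypothetical helper names (ViewProperty, register_view_property), loosely mirroring how torch.nn.utils.parametrize swaps the instance's class; the real helpers are SsdFlatParameterViewProperty and _register_property in ssd_offload:

```python
import torch
import torch.nn as nn

class ViewProperty:
    """Hypothetical descriptor whose getter always returns the current view."""

    def __init__(self, get_view):
        self.get_view = get_view

    def __get__(self, obj, objtype=None):
        return self.get_view()

def register_view_property(module: nn.Module, name: str, get_view) -> None:
    # Drop the original Parameter, then install a class-level descriptor so
    # module.<name> always resolves to the flat parameter's current view.
    delattr(module, name)
    patched = type(f"Patched{type(module).__name__}", (type(module),), {name: ViewProperty(get_view)})
    module.__class__ = patched  # per-instance subclass; other modules are unaffected

# Usage: layer.weight now tracks whatever the flat tensor currently holds,
# even if its .data is later swapped (e.g. for a CUDA copy).
flat = torch.randn(4 * 3, requires_grad=True)
layer = nn.Linear(3, 4, bias=False)
register_view_property(layer, "weight", lambda: flat.view(4, 3))
out = layer(torch.randn(2, 3))
```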
......@@ -66,7 +66,6 @@ else:
try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
import_ssd_offload = True
except ImportError:
......@@ -399,6 +398,10 @@ class FullyShardedDataParallel(nn.Module):
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
# Experimental feature for now. Use at your own risk.
self.ssd_offload = True if offload_config and offload_config.offload_type == "ssd_offload" else False
if self.ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.gradient_predivide_factor: float = gradient_predivide_factor or self._get_gradient_predivide_factor(
self.world_size
......@@ -761,16 +764,14 @@ class FullyShardedDataParallel(nn.Module):
p._is_sharded = True
# Replace p.data with the relevant shard.
if self.ssd_offload:
assert isinstance(p, SsdFlatParameter)
sharded_tensor, num_padded = self._get_shard(p.data)
p.point_to_resized_tensor(sharded_tensor)
self.numel_padded_per_param.append(num_padded)
p.to_file()
else:
orig_data = p.data
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)
......@@ -987,7 +988,7 @@ class FullyShardedDataParallel(nn.Module):
def _move_params_to_memory(self) -> None:
"""Move params from disk to CPU."""
for p in self.params:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
p.to_tensor()
def _load_state_dict(
......@@ -1138,6 +1139,11 @@ class FullyShardedDataParallel(nn.Module):
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard, _ = self._get_shard(full_tensor)
if self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.point_to_tensor(local_shard.view_as(p._fp32_shard).cpu())
else:
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
......@@ -1145,7 +1151,7 @@ class FullyShardedDataParallel(nn.Module):
if self.ssd_offload:
# Store tensors in the SSD buffer and free param storage.
for p in self.params:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file()
else:
self._use_fp32_param_shard()
......@@ -1281,10 +1287,11 @@ class FullyShardedDataParallel(nn.Module):
# shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation.
if self.ssd_offload:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
# Gradients also need to be offloaded to SSD otherwise it can result in
# OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.allow_unsafe_changes = True
p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file()
else:
......@@ -1441,7 +1448,7 @@ class FullyShardedDataParallel(nn.Module):
def _free_ssd_offload(self) -> None:
if self.ssd_offload:
for p in self.params:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
p.to_file(permit_when_tensor_none=True)
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
......@@ -1897,7 +1904,9 @@ class FullyShardedDataParallel(nn.Module):
if self.ssd_offload:
for p in self.params:
assert isinstance(p, SsdFlatParameter)
assert isinstance(p, ssd_offload.SsdParameter)
if not p.is_available():
self._ssd_offload_reset_param_device(p)
p.to_tensor()
self.has_full_params = False
......@@ -2175,12 +2184,24 @@ class FullyShardedDataParallel(nn.Module):
return consolidated_weights
@torch.no_grad()
def _ssd_offload_reset_param_device(self, param: Parameter) -> None:
assert isinstance(param, ssd_offload.SsdParameter)
if param.device != torch.device("cpu"):
param.data = param._fp32_shard
param.tensor = None
@torch.no_grad()
def _use_fp32_param_shard(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
if import_ssd_offload and self.ssd_offload:
assert isinstance(p, ssd_offload.SsdParameter)
self._ssd_offload_reset_param_device(p)
p.to_tensor()
else:
p.data = p._fp32_shard
@torch.no_grad()
......@@ -2192,6 +2213,9 @@ class FullyShardedDataParallel(nn.Module):
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
if self.ssd_offload:
p._fp16_shard.copy_(p.to(p._fp16_shard.device, non_blocking=True))
else:
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
......
......@@ -31,7 +31,19 @@ import torch
from torch import Tensor
import torch.nn as nn
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
try:
from fairscale.experimental.nn.ssd_offload import (
SsdFlatParameter,
SsdFlatParameterView,
SsdFlatParameterViewProperty,
_register_property,
)
import_ssd_offload = True
except ImportError:
import_ssd_offload = False
pass
from fairscale.utils.state_dict import replace_by_prefix_
if TYPE_CHECKING:
......@@ -116,7 +128,6 @@ class FlatParameter(nn.Parameter):
# Static types.
FlatTypes = Union[FlatParameter, SsdFlatParameter]
ParamGroups = Optional[Union[List[List[nn.Parameter]], List[nn.Parameter]]]
......@@ -159,6 +170,11 @@ class FlattenParamsWrapper(nn.Module):
ssd_directory: str = "",
):
super().__init__()
if ssd_offload and not import_ssd_offload:
raise ImportError(
f"Trying to enable ssd_offload when it was not successfully imported (likely due to old torch version, current = {torch.__version__})"
)
self.ssd_offload = ssd_offload
self._fpw_module = module
self.is_flattened = False
......@@ -205,7 +221,7 @@ class FlattenParamsWrapper(nn.Module):
# support.
raise ValueError(f"Incorrect param groups {len(overall_param_set)} vs {self.num_param_managed}")
self.flat_params: List[FlatTypes] = []
self.flat_params: List[nn.Parameter] = []
# Prepare flat param names.
if flat_param_names is None:
......@@ -215,7 +231,7 @@ class FlattenParamsWrapper(nn.Module):
if len(flat_param_names) != len(set(flat_param_names)):
raise ValueError("Each flat param must be given a unique name")
self.flat_param_names = [f"flat_param_{n}" for n in flat_param_names]
flat_param: Optional[FlatTypes] = None
flat_param: Optional[nn.Parameter] = None
# Init all flat_params.
for new_p_set in self._param_sets:
......@@ -224,6 +240,7 @@ class FlattenParamsWrapper(nn.Module):
assert ssd_directory != ""
(handle, fname) = tempfile.mkstemp(dir=ssd_directory, suffix="ssd_buf_param")
flat_param = SsdFlatParameter.from_tensors(tensors=params)
flat_param.allow_unsafe_changes = True
flat_param.set_file_params(fname, 0)
else:
flat_param = FlatParameter(params, params[0].requires_grad)
......@@ -309,13 +326,13 @@ class FlattenParamsWrapper(nn.Module):
@property
def _param_infos(self) -> Iterator[Tuple[str, nn.Module, str]]:
return chain(*[p._param_infos for p in self.flat_params])
return chain(*[p._param_infos for p in self.flat_params]) # type: ignore
@property
def _shared_param_infos(self) -> Iterator[Tuple[str, str, nn.Module, str, nn.Module, str]]:
return chain(*[p._shared_param_infos for p in self.flat_params])
return chain(*[p._shared_param_infos for p in self.flat_params]) # type: ignore
def _flatten_params(self, flat_params: List[FlatTypes]) -> None:
def _flatten_params(self, flat_params: List[nn.Parameter]) -> None:
"""Flatten the managed parameters and replaced the original
attributes with views to the flat params.
"""
......@@ -331,6 +348,7 @@ class FlattenParamsWrapper(nn.Module):
# deregister the names as parameters
for _, m, n in self._param_infos:
delattr(m, n)
for _, _, m, n, _, _ in self._shared_param_infos:
delattr(m, n)
......@@ -372,6 +390,11 @@ class FlattenParamsWrapper(nn.Module):
ps = self.get_param_views()
param_views = []
for (_, m, n), p in zip(self._param_infos, ps):
if self.ssd_offload:
assert isinstance(p, SsdFlatParameterView)
_register_property(m, n, SsdFlatParameterViewProperty(p.parent, p.id))
else:
setattr(m, n, p) # This will set as plain attr
param_views.append(p)
......@@ -498,13 +521,13 @@ class FlattenParamsWrapper(nn.Module):
gens = []
for p, data in zip(params, external_data_list):
gens.append(p.get_param_views(data))
gens.append(p.get_param_views(data)) # type: ignore
return chain(*gens)
def metadata(self, flat_param_idx: int) -> Tuple[List[str], Sequence[torch.Size], List[int]]:
"""Return metadata for a flat param given its index in the flat_params list."""
return self.flat_params[flat_param_idx].metadata()
return self.flat_params[flat_param_idx].metadata() # type: ignore
def _post_state_dict_hook(
......
......@@ -8,6 +8,7 @@ Testing SsdFlatParameter and SsdTensorHandle modules.
"""
import filecmp
import functools
import os
import tempfile
......@@ -15,11 +16,12 @@ import numpy as np
import pytest
import torch
import fairscale.experimental.nn.ssd_offload as so
from fairscale.utils import torch_version
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
def _init():
......@@ -77,10 +79,32 @@ def test_ssd_handle_dispatch_bwd():
assert torch.equal(ssd_handle.grad, orig_copy.grad)
def test_ssd_handle_train_simple():
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
def test_ssd_handle_dispatch_bwd_hook():
_init()
def post_backward_hook(name, grad):
print(f"BACKWARD HOOK for tensor {name} CALLED")
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4), requires_grad=True)
orig_copy = orig_tensor.clone().detach().requires_grad_(True)
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
one = torch.ones((1), requires_grad=True).cuda()
orig_copy = ssd_handle.data
cuda_copy = ssd_handle.to("cuda").detach().requires_grad_(True)
ssd_handle.data = cuda_copy
ssd_handle.register_hook(functools.partial(post_backward_hook, "ssd_handle"))
one.register_hook(functools.partial(post_backward_hook, "one"))
y1 = ssd_handle + one
y1.sum().backward()
def test_ssd_handle_train_simple():
_init()
with tempfile.NamedTemporaryFile() as f:
......@@ -92,6 +116,7 @@ def test_ssd_handle_train_simple():
orig_copy.requires_grad = True
ssd_handle = so.SsdTensorHandle.from_tensor(orig_tensor)
ssd_handle.flush_on_dirty = False
ssd_handle.set_file_params(f.name, 0)
ssd_handle.to_file(release_tensor_after_write=True)
......@@ -102,15 +127,15 @@ def test_ssd_handle_train_simple():
y1 = ssd_handle + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_handle.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = orig_copy + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_handle.point_to_file(f.name, 0)
assert torch.equal(ssd_handle.to_tensor(), orig_copy)
......@@ -169,9 +194,6 @@ def test_torch_save_load_ssd_flat_param_on_mem():
def test_ssd_param_train_simple():
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
_init()
with tempfile.NamedTemporaryFile() as f:
orig_tensor = torch.randn((4, 4))
......@@ -183,6 +205,7 @@ def test_ssd_param_train_simple():
ssd_param = so.SsdParameter(orig_tensor.shape, orig_tensor.dtype)
ssd_param.point_to_tensor(orig_copy)
ssd_param.flush_on_dirty = False
ssd_param.set_file_params(f.name, 0)
ssd_param.to_file(release_tensor_after_write=True)
......@@ -193,15 +216,17 @@ def test_ssd_param_train_simple():
y1 = ssd_param + 1
optimizer_ssd.zero_grad()
y1.sum().backward()
# Test to see if Dirty is being calculated correctly when optimizer modifies
# ssd_param
assert ssd_param.storage_state is so.StorageState.ON_CPU_CLEAN
optimizer_ssd.step()
assert ssd_param.storage_state is so.StorageState.ON_CPU_DIRTY
y2 = param + 1
optimizer_orig.zero_grad()
y2.sum().backward()
optimizer_orig.step()
# make sure we are using the file version not the cached tensor
ssd_param.point_to_file(f.name, 0)
assert torch.equal(ssd_param.to_tensor(), param)
......@@ -211,8 +236,186 @@ def test_ssd_flat_parameter_basic():
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
param_views = list(ssd_flat_param.get_param_views())
assert refa_param.shape == param_views[0].shape
assert refb_param.shape == param_views[1].shape
assert refc_param.shape == param_views[2].shape
assert torch.equal(refa_param, param_views[0])
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
def test_ssd_flat_parameter_view_modify():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=False)
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors([refa_param, refb_param, refc_param], direct_to_file=False)
ssd_flat_param.set_file_params(f.name, 0)
ssd_flat_param.flush_on_dirty = False
param_views = list(ssd_flat_param.get_param_views())
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
ssd_flat_param.to_file()
assert ssd_flat_param.storage_state == so.StorageState.ON_DISK
assert param_views[0].tensor is None
param_views[0] += 0.1
assert ssd_flat_param.storage_state == so.StorageState.ON_CPU_DIRTY
def test_ssd_flat_parameter_view_bwd():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
refa_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refb_param = (
torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
refc_param = (
torch.nn.Parameter(torch.rand((128), dtype=torch.float32), requires_grad=True)
.to("cpu")
.detach()
.requires_grad_()
)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.data = cuda_copy
one = torch.ones((1), requires_grad=True, device=ssd_flat_param.device)
y1 = ssd_flat_param.views[0] + one
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_view_bwd_parameterization():
_init()
hooks_called = []
def post_backward_hook(name, hooks_called, *grads):
print(f"BACKWARD HOOK for tensor {name} CALLED")
hooks_called.append(name)
with tempfile.NamedTemporaryFile() as f:
layer1 = torch.nn.Linear(32, 4, bias=False)
layer2 = torch.nn.Linear(32, 4, bias=False)
layer3 = torch.nn.Linear(128, 1, bias=False)
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[layer1.weight, layer2.weight, layer3.weight], direct_to_file=False, filename=f.name, offset=0
)
torch.nn.utils.parametrize.register_parametrization(
layer1, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 0)
)
torch.nn.utils.parametrize.register_parametrization(
layer2, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 1)
)
torch.nn.utils.parametrize.register_parametrization(
layer3, "weight", so.SsdFlatParameterViewParameterization(ssd_flat_param, 2)
)
orig_copy = ssd_flat_param.data
cuda_copy = ssd_flat_param.to("cuda").detach().requires_grad_()
cpu_copy = ssd_flat_param.to("cpu").detach().requires_grad_()
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_orig", hooks_called))
ssd_flat_param.to_file(release_tensor_after_write=False)
ssd_flat_param.data = cuda_copy
one = torch.ones(layer1.weight.shape, requires_grad=True, device=ssd_flat_param.device)
y1 = layer1.forward(one)
y2 = cuda_copy + 1
# ssd_flat_param.to_file()
# ssd_flat_param.data = orig_copy
p_tmp = ssd_flat_param.expand_as(ssd_flat_param) # Get a grad_fn on p_tmp.
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
grad_acc.register_hook(functools.partial(post_backward_hook, "GradAccumulation_cuda", hooks_called))
ssd_flat_param.views[0].register_hook(
functools.partial(post_backward_hook, "ssd_flat_param.views[0]", hooks_called)
)
ssd_flat_param.register_hook(functools.partial(post_backward_hook, "ssd_flat_param", hooks_called))
one.register_hook(functools.partial(post_backward_hook, "one", hooks_called))
y1.sum().backward()
y2.sum().backward()
assert "GradAccumulation_cuda" in hooks_called
assert "ssd_flat_param.views[0]" in hooks_called
assert "ssd_flat_param" in hooks_called
assert "one" in hooks_called
def test_ssd_flat_parameter_direct_to_file():
_init()
with tempfile.NamedTemporaryFile() as f:
refa_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refb_param = torch.nn.Parameter(torch.rand((32, 4), dtype=torch.float32))
refc_param = torch.nn.Parameter(torch.rand((128), dtype=torch.float32))
ssd_flat_param = so.SsdFlatParameter.from_tensors(
[refa_param, refb_param, refc_param], direct_to_file=True, filename=f.name, offset=0
)
param_views = list(ssd_flat_param.get_param_views())
......@@ -224,3 +427,8 @@ def test_ssd_flat_parameter_basic():
assert torch.equal(refb_param, param_views[1])
assert torch.equal(refc_param, param_views[2])
ssd_flat_param.to_file()
assert not ssd_flat_param.is_available()
first_value = param_views[0][0][0].item()
assert ssd_flat_param.is_available()
assert first_value == refa_param[0][0].item()
......@@ -16,17 +16,17 @@ import torch
from torch import nn
import torch.distributed
import fairscale.experimental.nn.ssd_offload as so
try:
import fairscale.experimental.nn.ssd_offload as so
except ImportError as ie:
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
pytestmark = pytest.mark.skipif(True, reason=ie.msg)
pass
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
print(f"torch version {torch_version()}")
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0")
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod
......@@ -137,8 +137,6 @@ def rename_test(testcase_func, param_num, param):
class TestSsdMemory(DistributedTest):
def test_memory_benchmark(self):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_memory_benchmark, config={})
spawn_and_init(test_fn)
......@@ -218,8 +216,6 @@ class TimeKeeper:
class TestModuleProperties(DistributedTest):
@parameterized.expand(CONFIG, name_func=rename_test)
def test_named_parameters(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_named_params, config=config)
spawn_and_init(test_fn)
......@@ -264,23 +260,17 @@ class TestModuleProperties(DistributedTest):
class TestSsdLoading(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_eval(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offload_eval, config=config)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG, name_func=rename_test)
def test_transformer_parameterized(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
spawn_and_init(functools.partial(self._test_identical_outputs_eval, TransformerWithSharedParams, config))
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_ssd_offloading_train_flatten_params_wrapper(self, config):
if torch_version() >= (1, 12, 0):
pytest.skip("to be fixed")
test_fn = functools.partial(self._test_ssd_offloading_train_flatten_params_wrapper, config=config)
spawn_and_init(test_fn)
......@@ -288,6 +278,8 @@ class TestSsdLoading(DistributedTest):
@classmethod
def _test_ssd_offloading_train_flatten_params_wrapper(self, rank, group, config):
SIZE = 16 * 16
LR = 0.01
MOMENTUM = 0.1
model = SimpleLinear(group, input_size=SIZE, output_size=SIZE, layers=4)
with tempfile.TemporaryDirectory() as current_tempdir:
......@@ -305,7 +297,7 @@ class TestSsdLoading(DistributedTest):
model = FullyShardedDataParallel(model, **config)
model_device = torch.device("cuda")
model.train()
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint_file = tempfile.NamedTemporaryFile()
checkpoint_load_directory = tempfile.TemporaryDirectory(prefix="checkpoint_dir")
......@@ -322,9 +314,13 @@ class TestSsdLoading(DistributedTest):
input = model.get_input(torch.device("cuda"))
output = model(*input)
pre_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} pre_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
optim.step()
if i == 0:
......@@ -332,18 +328,23 @@ class TestSsdLoading(DistributedTest):
# so.torch_saver.save({"model": model.state_dict(), "optim": optim.state_dict()}, checkpoint_file.name)
torch.save({"model": model.state_dict()}, checkpoint_file.name)
# reset momentum just after checkpoint save
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
checkpoint = torch.load(checkpoint_file.name)
model.load_state_dict(checkpoint["model"])
# reset momentum just after checkpoint load
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optim = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
# do more iterations after loading checkpoint
for i in range(ITERATIONS - 1):
optim.zero_grad()
input = model.get_input(torch.device("cuda"))
output = model(*input)
post_checkpoint_last_output = output
"""
param_itr = iter(model.named_parameters())
p_name, p_val = next(param_itr)
print(f"i={i} post_checkpoint {p_name} = {p_val[0].item()}")
"""
loss = model.module.get_loss(input, output).to(model_device)
assert loss.dtype == torch.float32
......