Unverified Commit 77f92b38 authored by Min Xu, committed by GitHub

[feat]: add summon_full_params context mgr (#433)

* [feat]: add summon_full_params context mgr

* fix

* fix

* addressed comments

* fixed the state_dict copy

* lint
parent f7813d6d
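For context, a minimal usage sketch of the new context manager (not part of this commit): the wrapped module, device placement, and process-group setup are illustrative assumptions, and the import path is assumed from the repository layout at the time.

    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Assumes torch.distributed is already initialized and a GPU is available.
    model = FSDP(nn.Linear(8, 8).cuda())

    # ... run forward/backward and optimizer steps as usual ...

    # Outside of any forward/backward pass, gather the full (unsharded)
    # parameters on every rank for inspection or extra processing.
    with model.summon_full_params():
        full_weight = model.module.weight.data.clone()

    # On exit, p.data points back at the local FP32 shard and the gathered
    # full parameters are freed.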
@@ -51,6 +51,7 @@ class TrainingState(Enum):
IDLE = auto()
FORWARD = auto()
BACKWARD = auto()
SUMMON_FULL_PARAMS = auto()
class FullyShardedDataParallel(nn.Module):
@@ -383,17 +384,18 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
torch.cuda.synchronize()
self._lazy_init()
self._rebuild_full_params()
self._all_buffers_to(dtype=torch.float32) # Buffers dtype stays consistent with parameters.
state_dict = self.module.state_dict(*args, **kwargs)
# We don't free the params after generating the state dict, since
# freeing is done in-place (via the Storage) and would corrupt the
# returned state dict. However, we need to maintain the invariant that
# p.data corresponds to the FP32 param shard, so we do that here.
self._use_fp32_param_shard()
self._all_buffers_to(dtype=self.compute_dtype)
with self.summon_full_params():
# Buffers dtype stays consistent with parameters.
self._all_buffers_to(dtype=torch.float32)
state_dict = self.module.state_dict(*args, **kwargs)
# We copy the state_dict since full param will be freed after
# we exit the summon_full_params() context.
for key in state_dict.keys():
state_dict[key] = state_dict[key].clone()
# In case we are in mixed precision, restore buffers back to fp16.
self._all_buffers_to(dtype=self.compute_dtype)
return state_dict
# TODO (Min): figuring out how to do typing for this overloaded function.
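The clone above matters because the values returned by module.state_dict() alias the gathered full-parameter storage, which is freed in place when the context exits. A hedged illustration of the same pattern from user code (variable names are made up):

    # Inside summon_full_params(), state_dict values share storage with the
    # gathered full params; cloning detaches them before that storage is
    # freed (resized in place) on context exit.
    with model.summon_full_params():
        aliased = model.module.state_dict()                # aliases full params
        safe = {k: v.clone() for k, v in aliased.items()}  # detached copies
    # After exit, entries in `aliased` should not be used; `safe` stays valid.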
@@ -419,11 +421,8 @@ class FullyShardedDataParallel(nn.Module):
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
torch.cuda.synchronize()
self._lazy_init()
self._rebuild_full_params()
output = self.module.load_state_dict(state_dict, strict)
self._free_full_params()
with self.summon_full_params():
output = self.module.load_state_dict(state_dict, strict)
return output
def load_local_state_dict(
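With this change both state_dict() and load_state_dict() go through summon_full_params(), so they remain collective calls that every rank must make. A hedged save/load pattern (the checkpoint filename is illustrative):

    import torch
    import torch.distributed as dist

    full_sd = model.state_dict()              # collective: call on all ranks
    if dist.get_rank() == 0:
        torch.save(full_sd, "checkpoint.pt")  # only rank 0 needs to write it
    dist.barrier()
    model.load_state_dict(full_sd)            # collective: call on all ranks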
@@ -457,6 +456,32 @@ class FullyShardedDataParallel(nn.Module):
for m, old_flag in old_flags:
m.require_backward_grad_sync = old_flag
@contextlib.contextmanager
def summon_full_params(self) -> Generator:
"""
A context manager to expose full params for the underlying model.
Can be useful *after* forward/backward for a model to get the params
for additional processing or checking.
This can be used on inner FSDPs.
This can *not* be used within a forward or backward pass. Nor can forward
and backward be started from within this context.
"""
torch.cuda.synchronize()
self._lazy_init()
self.assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into
# forward/backward.
self.training_state = TrainingState.SUMMON_FULL_PARAMS
self._rebuild_full_params()
try:
yield
finally:
self._free_full_params()
self._use_fp32_param_shard()
self.training_state = TrainingState.IDLE
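Per the docstring, the context manager can also be used on inner FSDP instances. A hedged sketch, where inner_fsdp is a hypothetical attribute name for an FSDP-wrapped submodule:

    # Hypothetical nested setup: `model` is the outer FSDP and one of its
    # submodules is itself wrapped in FSDP.
    inner = model.module.inner_fsdp        # hypothetical attribute
    with inner.summon_full_params():
        # Only this submodule's parameters are gathered to their full,
        # unsharded size; other parameters stay sharded.
        print(inner.module.weight.shape)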
def _reset_lazy_init(self) -> None:
"""Reset instance so :func:`_lazy_init` will run on the next forward."""
self._is_root: Optional[bool] = None
@@ -815,8 +840,11 @@ class FullyShardedDataParallel(nn.Module):
if self.move_grads_to_cpu:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
# A backward pass is done.
self.training_state = TrainingState.IDLE
# A backward pass is done, update root and nested FSDP's flags.
for m in self.modules(): # includes self
if isinstance(m, FullyShardedDataParallel):
m.assert_state(TrainingState.BACKWARD)
m.training_state = TrainingState.IDLE
@torch.no_grad()
def _rebuild_full_params(self) -> None:
@@ -851,6 +879,10 @@ class FullyShardedDataParallel(nn.Module):
@torch.no_grad()
def _use_full_params(self) -> None:
"""Switching p.data pointers to use the full params.
Note: this is used assuming full param gathering is already done.
"""
for p in self.params:
if not p._is_sharded:
if self.mixed_precision:
@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
@@ -16,6 +17,11 @@ def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tu
def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
if torch.is_tensor(x):
return fn(x)
elif isinstance(x, OrderedDict):
od = OrderedDict()
for key, value in x.items():
od[key] = _apply(value)
return od
elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()}
elif isinstance(x, list):
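With the new OrderedDict branch above, apply_to_tensors preserves both the mapping type and the key order of the containers it walks, while non-tensor values pass through unchanged. A hedged standalone check (the import path is assumed from the repository layout):

    from collections import OrderedDict

    import torch
    from fairscale.utils.containers import apply_to_tensors  # path assumed

    od = OrderedDict(a=torch.ones(2), b="non-tensor values pass through")
    out = apply_to_tensors(lambda t: t * 2, od)
    assert isinstance(out, OrderedDict)          # type preserved
    assert list(out.keys()) == ["a", "b"]        # key order preserved
    assert torch.equal(out["a"], torch.full((2,), 2.0))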
@@ -21,8 +21,9 @@
# mypy: disallow_untyped_decorators = False
"""
Collection of some testing utilities for the Fairscale library. Please complement as you see fit, but refrain from ad-hoc test utils
within the different feature sets and relative imports.
Collection of some testing utilities for the Fairscale library. Please complement as
you see fit, but refrain from ad-hoc test utils within the different feature sets and
relative imports.
"""
import functools
@@ -27,15 +27,22 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
result = dist_init(rank, world_size, tempfile_name, unused)
assert result, "Dist init failed"
my_lr = 0.1
if test_case["assert_ref_out"]:
with torch.no_grad():
# Compute one iteration local output.
weight = model.weight.T.clone().cuda()
v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
ref_out = torch.matmul(v, weight)
ref_forward_output_my_rank = torch.matmul(v, weight)
# Compute one iteration global weight update.
v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
ref_weight_out = weight - grad.T * my_lr
model.to("cuda")
assert isinstance(fsdp_config, dict), str(fsdp_config)
model = FSDP(model, **fsdp_config)
optim = SGD(model.parameters(), lr=0.1)
optim = SGD(model.parameters(), lr=my_lr)
inputs = test_case["inputs"]
assert len(inputs) == 1 or not test_case["assert_ref_out"]
assert len(inputs[0]) >= world_size
@@ -45,9 +52,16 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test
out.sum().backward()
optim.step()
optim.zero_grad()
if test_case["assert_ref_out"]:
with model.summon_full_params():
weight_out = model.module.weight.data.T.clone()
# make sure we can do more fwd/bwd
loss = model(in_data)
loss.sum().backward()
if test_case["assert_ref_out"]:
torch.testing.assert_allclose(ref_out, out)
torch.testing.assert_allclose(ref_forward_output_my_rank, out)
torch.testing.assert_allclose(ref_weight_out, weight_out)
model.assert_state(TrainingState.IDLE)
teardown()
@@ -9,6 +9,7 @@
""" Test utility classes from containers.py. """
from collections import OrderedDict
import random
import pytest
@@ -44,6 +45,9 @@ def test_apply_to_tensors(devices):
data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2))))
od = OrderedDict()
od["k"] = "value"
data.append(od)
total = 0
@@ -52,8 +56,10 @@ def test_apply_to_tensors(devices):
total += t.numel()
return t
apply_to_tensors(fn, data)
new_data = apply_to_tensors(fn, data)
assert total == expected, f"{total} vs. {expected}"
for i, v in enumerate(data):
assert type(new_data[i]) == type(v), f"expected type {type(v)} got {type(new_data[i])}"
def test_pack_unpack():