Unverified Commit 77f92b38 authored by Min Xu, committed by GitHub

[feat]: add summon_full_params context mgr (#433)

* [feat]: add summon_full_params context mgr

* fix

* fix

* addressed comments

* fixed the state_dict copy

* lint
parent f7813d6d
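
The new `summon_full_params()` context manager temporarily gathers the full (unsharded) parameters outside of forward/backward and frees them again on exit. A minimal usage sketch, assuming a working `torch.distributed` process group and a CUDA device; the toy model, `flatten_parameters=False`, and tensor shapes are illustrative only:

```python
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized and the CUDA device is set.
model = FSDP(torch.nn.Linear(8, 8).cuda(), flatten_parameters=False)

out = model(torch.randn(4, 8).cuda())
out.sum().backward()

# Outside forward/backward the params are sharded again; summon the full ones
# to inspect or copy them, then let the context free them on exit.
with model.summon_full_params():
    full_weight = model.module.weight.data.clone()
```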
@@ -51,6 +51,7 @@ class TrainingState(Enum):
     IDLE = auto()
     FORWARD = auto()
     BACKWARD = auto()
+    SUMMON_FULL_PARAMS = auto()


 class FullyShardedDataParallel(nn.Module):
@@ -383,16 +384,17 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
-        self._rebuild_full_params()
-        self._all_buffers_to(dtype=torch.float32)  # Buffers dtype stays consistent with parameters.
-        state_dict = self.module.state_dict(*args, **kwargs)
-        # We don't free the params after generating the state dict, since
-        # freeing is done in-place (via the Storage) and would corrupt the
-        # returned state dict. However, we need to maintain the invariant that
-        # p.data corresponds to the FP32 param shard, so we do that here.
-        self._use_fp32_param_shard()
+        with self.summon_full_params():
+            # Buffers dtype stays consistent with parameters.
+            self._all_buffers_to(dtype=torch.float32)
+            state_dict = self.module.state_dict(*args, **kwargs)
+            # We copy the state_dict since full param will be freed after
+            # we exit the summon_full_params() context.
+            for key in state_dict.keys():
+                state_dict[key] = state_dict[key].clone()
         # In case we are in mixed precision, restore buffers back to fp16.
         self._all_buffers_to(dtype=self.compute_dtype)
         return state_dict
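
The per-key `.clone()` is load-bearing here: `nn.Module.state_dict()` returns tensors that share storage with the parameters, and `summon_full_params()` frees the full params in place when it exits, so un-cloned entries would be left pointing at reclaimed storage. A small, FSDP-independent illustration of that aliasing:

```python
import torch

m = torch.nn.Linear(2, 2)
sd = m.state_dict()

# The state_dict entry aliases the parameter's storage ...
assert sd["weight"].data_ptr() == m.weight.data_ptr()

# ... whereas a clone owns its own memory and survives later in-place changes.
safe = sd["weight"].clone()
assert safe.data_ptr() != m.weight.data_ptr()
```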
@@ -419,11 +421,8 @@ class FullyShardedDataParallel(nn.Module):
         .. warning:: This needs to be called on all ranks, since synchronization
             primitives will be used.
         """
-        torch.cuda.synchronize()
-        self._lazy_init()
-        self._rebuild_full_params()
-        output = self.module.load_state_dict(state_dict, strict)
-        self._free_full_params()
+        with self.summon_full_params():
+            output = self.module.load_state_dict(state_dict, strict)
         return output
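
Combined with `state_dict()` above, this gives a straightforward full-checkpoint round trip. A hedged sketch, reusing `model` from the earlier example; the path is a placeholder and rank-0-only saving is just one common convention:

```python
import torch
import torch.distributed as dist

# Every rank must call state_dict()/load_state_dict(): they use collectives.
full_sd = model.state_dict()
if dist.get_rank() == 0:
    torch.save(full_sd, "/tmp/full_ckpt.pt")  # placeholder path
dist.barrier()

# Later, every rank reloads the same full (unsharded) state dict.
model.load_state_dict(torch.load("/tmp/full_ckpt.pt", map_location="cpu"))
```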
     def load_local_state_dict(
@@ -457,6 +456,32 @@ class FullyShardedDataParallel(nn.Module):
             for m, old_flag in old_flags:
                 m.require_backward_grad_sync = old_flag

+    @contextlib.contextmanager
+    def summon_full_params(self) -> Generator:
+        """
+        A context manager to expose full params for the underlying model.
+        Can be useful *after* forward/backward for a model to get the params
+        for additional processing or checking.
+
+        This can be used on inner FSDPs.
+
+        This can *not* be used within a forward or backward pass. Nor can forward
+        and backward be started from within this context.
+        """
+        torch.cuda.synchronize()
+        self._lazy_init()
+        self.assert_state(TrainingState.IDLE)
+        # Set the state so that we assert when trying to go into
+        # forward/backward.
+        self.training_state = TrainingState.SUMMON_FULL_PARAMS
+        self._rebuild_full_params()
+        try:
+            yield
+        finally:
+            self._free_full_params()
+            self._use_fp32_param_shard()
+            self.training_state = TrainingState.IDLE
+
     def _reset_lazy_init(self) -> None:
         """Reset instance so :func:`_lazy_init` will run on the next forward."""
         self._is_root: Optional[bool] = None
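
The `try`/`finally` around the `yield` in `summon_full_params` is what guarantees the full params get freed and the training state reset even if the caller's block raises. The same pattern in isolation (the names below are illustrative, not FSDP APIs):

```python
import contextlib
from types import SimpleNamespace

@contextlib.contextmanager
def temporary_state(obj, new_state):
    previous = obj.state
    obj.state = new_state        # enter: switch to the temporary state
    try:
        yield
    finally:
        obj.state = previous     # exit: always restored, even on exceptions

holder = SimpleNamespace(state="IDLE")
with temporary_state(holder, "SUMMON_FULL_PARAMS"):
    assert holder.state == "SUMMON_FULL_PARAMS"
assert holder.state == "IDLE"
```

Everything before the `yield` runs on entry and everything in `finally` runs on exit, which is how the new method pairs `_rebuild_full_params()` with `_free_full_params()`.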
@@ -815,8 +840,11 @@ class FullyShardedDataParallel(nn.Module):
         if self.move_grads_to_cpu:
             # Wait for the non-blocking GPU -> CPU grad transfers to finish.
             torch.cuda.current_stream().synchronize()
-        # A backward pass is done.
-        self.training_state = TrainingState.IDLE
+        # A backward pass is done, update root and nested FSDP's flags.
+        for m in self.modules():  # includes self
+            if isinstance(m, FullyShardedDataParallel):
+                m.assert_state(TrainingState.BACKWARD)
+                m.training_state = TrainingState.IDLE
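
Resetting the flag on every nested wrapper works because `nn.Module.modules()` yields the module itself plus all submodules recursively, so the root FSDP can reach FSDP instances wrapped deeper in the tree. For example:

```python
import torch.nn as nn

root = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))

# modules() walks the whole tree, including `root` itself.
assert sum(isinstance(m, nn.Linear) for m in root.modules()) == 2
assert any(m is root for m in root.modules())
```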
     @torch.no_grad()
     def _rebuild_full_params(self) -> None:
@@ -851,6 +879,10 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _use_full_params(self) -> None:
+        """Switching p.data pointers to use the full params.
+
+        Note: this is used assuming full param gathering is already done.
+        """
         for p in self.params:
             if not p._is_sharded:
                 if self.mixed_precision:
......
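
The docstring added to `_use_full_params` points at the pointer-swapping trick FSDP relies on: assigning to `p.data` repoints an existing `nn.Parameter` at different storage without creating a new parameter object, so optimizer and module references stay valid. A minimal illustration:

```python
import torch

p = torch.nn.Parameter(torch.zeros(4))
full = torch.arange(4.0)

p.data = full  # p now reads and writes full's storage; the Parameter object is unchanged
assert p.data_ptr() == full.data_ptr()
```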
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import OrderedDict
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import torch
@@ -16,6 +17,11 @@ def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
     def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
         if torch.is_tensor(x):
             return fn(x)
+        elif isinstance(x, OrderedDict):
+            od = OrderedDict()
+            for key, value in x.items():
+                od[key] = _apply(value)
+            return od
         elif isinstance(x, dict):
             return {key: _apply(value) for key, value in x.items()}
         elif isinstance(x, list):
......
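
With the new branch, `apply_to_tensors` preserves the container type for `OrderedDict` inputs instead of collapsing them to a plain `dict`, which is what state dicts need. A quick check of the intended behavior, mirroring the unit test further down; this assumes the module lives at `fairscale.utils.containers`, as the test name suggests:

```python
from collections import OrderedDict

import torch

from fairscale.utils.containers import apply_to_tensors

data = OrderedDict(a=torch.ones(2), b="not a tensor")
out = apply_to_tensors(lambda t: t * 2, data)

assert isinstance(out, OrderedDict)              # container type is preserved
assert torch.equal(out["a"], torch.ones(2) * 2)  # tensors go through fn
assert out["b"] == "not a tensor"                # non-tensor leaves pass through
```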
@@ -21,8 +21,9 @@
 # mypy: disallow_untyped_decorators = False

 """
-Collection of some testing utilities for the Fairscale library. Please complement as you see fit, but refrain from ad-hoc test utils
-within the different feature sets and relative imports.
+Collection of some testing utilities for the Fairscale library. Please complement as
+you see fit, but refrain from ad-hoc test utils within the different feature sets and
+relative imports.
 """

 import functools
......
@@ -27,15 +27,22 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
     result = dist_init(rank, world_size, tempfile_name, unused)
     assert result, "Dist init failed"
+    my_lr = 0.1
+
     if test_case["assert_ref_out"]:
         with torch.no_grad():
+            # Compute one iteration local output.
             weight = model.weight.T.clone().cuda()
             v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
-            ref_out = torch.matmul(v, weight)
+            ref_forward_output_my_rank = torch.matmul(v, weight)
+            # Compute one iteration global weight update.
+            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
+            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
+            ref_weight_out = weight - grad.T * my_lr
     model.to("cuda")
     assert isinstance(fsdp_config, dict), str(fsdp_config)
     model = FSDP(model, **fsdp_config)
-    optim = SGD(model.parameters(), lr=0.1)
+    optim = SGD(model.parameters(), lr=my_lr)
     inputs = test_case["inputs"]
     assert len(inputs) == 1 or not test_case["assert_ref_out"]
     assert len(inputs[0]) >= world_size
@@ -45,9 +52,16 @@ def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
         out.sum().backward()
         optim.step()
         optim.zero_grad()
+        if test_case["assert_ref_out"]:
+            with model.summon_full_params():
+                weight_out = model.module.weight.data.T.clone()
+                # make sure we can do more fwd/bwd
+                loss = model(in_data)
+                loss.sum().backward()
+
     if test_case["assert_ref_out"]:
-        torch.testing.assert_allclose(ref_out, out)
+        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
+        torch.testing.assert_allclose(ref_weight_out, weight_out)
     model.assert_state(TrainingState.IDLE)
     teardown()
......
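
The reference weight update in the test leans on two facts: for `loss = (v @ w.T).sum()`, every row of `dloss/dw` equals the column-sums of `v`; and FSDP, like DDP, averages gradients across ranks, which is why the test divides by `world_size` before applying one SGD step. A small autograd check of the first fact (shapes are arbitrary placeholders):

```python
import torch

v = torch.randn(4, 3)                      # a batch of inputs
w = torch.randn(5, 3, requires_grad=True)  # a Linear-style weight [out, in]

(v @ w.T).sum().backward()

# Every row of the gradient equals the column-sums of the inputs.
assert torch.allclose(w.grad, v.sum(0).expand_as(w.grad))
```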
@@ -9,6 +9,7 @@
 """ Test utility classes from containers.py. """

+from collections import OrderedDict
 import random

 import pytest
@@ -44,6 +45,9 @@ def test_apply_to_tensors(devices):
     data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3})
     data.insert(0, set(["x", get_a_tensor(), get_a_tensor()]))
     data.append(([1], get_a_tensor(), (1), [get_a_tensor()], set((1, 2))))
+    od = OrderedDict()
+    od["k"] = "value"
+    data.append(od)

     total = 0
@@ -52,8 +56,10 @@ def test_apply_to_tensors(devices):
         total += t.numel()
         return t

-    apply_to_tensors(fn, data)
+    new_data = apply_to_tensors(fn, data)
     assert total == expected, f"{total} vs. {expected}"
+    for i, v in enumerate(data):
+        assert type(new_data[i]) == type(v), f"expected type {type(v)} got {type(new_data[i])}"


 def test_pack_unpack():
......