Unverified commit fa1b85fb authored by Myle Ott, committed by GitHub

[fix][FSDP] fix weight init when using apply() (fixes #490 and #444) (#543)

* Add new test for weight init (fails)
* Set FSDP.compute_device so summon_full_params works before module moves to CUDA
* Override FSDP.apply to enable custom weight init
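A minimal usage sketch of the fixed behavior (hedged: `MyModel` and `init_weights` are illustrative names, not part of this diff, and a torch.distributed process group is assumed to be initialized):

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def init_weights(module: nn.Module) -> None:
    # Custom weight init; the overridden FSDP.apply gathers the full
    # (unsharded) parameters before this runs on each submodule.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.02)

model = FSDP(MyModel())    # params may still be on CPU at this point
model.apply(init_weights)  # now works before the module moves to CUDA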
parent e3865549
@@ -9,7 +9,7 @@ from enum import Enum, auto
import functools
from math import inf
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
import torch
from torch.autograd import Variable
@@ -150,6 +150,11 @@ class FullyShardedDataParallel(nn.Module):
based on world_size, so the max shard size is roughly
``bucket_cap_mb / world_size``. Values <= 0 disable bucketing.
Default: 25.
compute_device (torch.device, Optional):
device for computation. If not given and module params are on a CUDA
device, the param's device will be used. If not given and module
params are on CPU, then the current CUDA device (as indicated by
``torch.cuda.current_device()``) will be used.
"""
def __init__(
@@ -165,6 +170,7 @@
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
bucket_cap_mb: int = 25,
compute_device: Optional[torch.device] = None,
):
super().__init__()
self.process_group = process_group or dist.new_group()
@@ -179,14 +185,21 @@
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")
compute_device = torch.device("cuda") if self.cpu_offload else next(module.parameters()).device
validate_process_group(compute_device, self.process_group)
if self.compute_device is None:
# Try to infer CUDA device from module parameters.
self.compute_device = next(module.parameters()).device
if self.compute_device.type != "cuda":
# Fall back to current CUDA device.
self.compute_device = torch.device("cuda")
validate_process_group(self.compute_device, self.process_group)
enable_pytorch_sync_bn(module)
# Only handle params which are not already sharded. This enables
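For intuition, the device-inference order implemented above can be sketched as a standalone helper (hypothetical, for illustration only):

from typing import Optional
import torch
import torch.nn as nn

def infer_compute_device(module: nn.Module, compute_device: Optional[torch.device] = None) -> torch.device:
    # Explicit argument wins; otherwise use the params' device if it is
    # already CUDA; otherwise fall back to the current CUDA device.
    if compute_device is not None:
        return compute_device
    device = next(module.parameters()).device
    if device.type != "cuda":
        device = torch.device("cuda")  # resolves to torch.cuda.current_device()
    return device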
@@ -239,11 +252,68 @@
def module(self) -> nn.Module:
return self._fsdp_wrapped_module # note: may be a FlattenParamsWrapper instance
@torch.no_grad()
def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
"""Move all buffers to the specified device and dtype, recursively."""
cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
self.apply(cast_fn)
def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
"""
Applies ``fn`` recursively to every submodule (as returned by
``.children()``) as well as self. Typical use includes initializing the
parameters of a model.
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
within another ``summon_full_params`` context.
Args:
fn (Callable[[nn.Module], None]): function to be applied to each submodule
Returns:
Module: self
"""
is_uninitialized = self._is_root is None
self.assert_state(TrainingState.IDLE)
with self.summon_full_params(recurse=False):
return_value = super().apply(fn)
# summon_full_params will call _lazy_init, which sets _is_root. However,
# apply() may be called directly on child instances to do weight
# init, so we should reset the _is_root flag in this case.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, FullyShardedDataParallel):
module._reset_lazy_init()
return return_value
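As the comment above notes, `apply` may be invoked directly on a nested child instance before the root has been initialized; the override then resets the lazy-init state it triggered. A hedged sketch (`my_init_fn` is a hypothetical weight-init callback; imports as in the sketch in the commit description):

inner = FSDP(nn.Linear(4, 4))
outer = FSDP(nn.Sequential(inner, nn.Linear(4, 4)))
inner.apply(my_init_fn)  # summons inner's full params, then resets the
                         # lazy-init state so `outer` can still become root
outer.apply(my_init_fn)  # root-level apply works normally afterwards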
def _cast_buffers(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
) -> None:
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
case of nested FSDP instances, we will respect the child instance's
``compute_device`` and ``buffer_dtype`` configuration.
Args:
device (torch.device, Optional):
device to cast buffers to (defaults to compute_device)
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
"""
if memo is None:
memo = set()
for module in self.modules():
if module is not self and isinstance(module, FullyShardedDataParallel):
# Allow any child FSDP instances to handle their own buffers.
module._cast_buffers(device=device, dtype=dtype, memo=memo)
elif module not in memo:
memo.add(module)
for name, buf in module.named_buffers(recurse=False):
if buf is None:
continue
buf = buf.to(device=device or self.compute_device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype or self.buffer_dtype)
setattr(module, name, buf)
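To illustrate the delegation to child instances (a toy setup; assumes an initialized process group, and note `_cast_buffers` is internal API):

inner = FSDP(nn.BatchNorm1d(8), buffer_dtype=torch.float32)
outer = FSDP(nn.Sequential(inner), buffer_dtype=torch.float16)
outer._cast_buffers()
# inner's running_mean/running_var stay float32, because the child FSDP
# instance casts its own buffers with its own buffer_dtype; integer
# buffers such as num_batches_tracked are moved but never re-typed.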
@property
def params_with_grad(self) -> List[Parameter]:
@@ -386,7 +456,10 @@
f"flatten_parameters={self.flatten_parameters}, "
f"cpu_offload={self.cpu_offload}, "
f"compute_dtype={self.compute_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}"
f"buffer_dtype={self.buffer_dtype}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"compute_device={self.compute_device}"
)
def __getattr__(self, name: str) -> Any:
@@ -443,7 +516,7 @@
self._lazy_init()
if self.mixed_precision:
# Buffers dtype stays consistent with parameters.
self._all_buffers_to(dtype=torch.float32)
self._cast_buffers(dtype=torch.float32)
if self._return_full_state_dict:
if self.training_state != TrainingState.SUMMON_FULL_PARAMS:
@@ -463,8 +536,8 @@
state_dict[k] = state_dict[k].cpu()
if self.mixed_precision:
# In case we are in mixed precision, restore buffers back to fp16.
self._all_buffers_to(dtype=self.buffer_dtype)
# In case we are in mixed precision, restore buffers back to buffer_dtype.
self._cast_buffers()
return state_dict
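Under mixed precision, the net effect is that the returned state dict holds fp32 buffers while the live module's buffers end up back in ``buffer_dtype``; sketched with a hypothetical ``model``:

fsdp = FSDP(model, mixed_precision=True).cuda()
sd = fsdp.state_dict()  # buffers cast up to fp32 while the copy is taken
# on return, live buffers are cast back to buffer_dtype (fp16 by default)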
# TODO (Min): figuring out how to do typing for this overloaded function.
@@ -572,7 +645,7 @@
recurse (bool, Optional): recursively summon all params for nested
FSDP instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
not guaranteed persist after the context manager exists;
not guaranteed to persist after the context manager exits;
enabling this can be slightly more efficient (default: False)
"""
if recurse:
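A usage sketch of this context manager (hedged; `fsdp` is an already-constructed instance):

with fsdp.summon_full_params(recurse=True, volatile=True):
    # inside the context, params are full-size and unsharded
    full_numel = sum(p.numel() for p in fsdp.parameters())
# with volatile=True, in-place edits made inside may not persist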
@@ -625,6 +698,9 @@
self._queue_wait_for_post_backward_closure: Optional[Callable] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
for p in self.params:
if hasattr(p, "_fp32_shard"):
del p._fp32_shard # reset _init_param_attributes
def _lazy_init(self) -> None:
"""Initialization steps that should happen lazily, typically right
@@ -642,12 +718,11 @@
self._set_is_root()
self._setup_streams()
if self.cpu_offload: # Buffers stay on GPU, and don't get sharded
self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
else:
self._all_buffers_to(dtype=self.buffer_dtype)
if self._is_root:
# Buffers stay on GPU, and don't get sharded. Since _cast_buffers
# applies recursively, we only call this from the root instance.
self._cast_buffers()
# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
@@ -684,10 +759,6 @@
if hasattr(p, "_fp32_shard"):
return
# Compute device defaults to CUDA when *cpu_offload* is enabled, or the
# param's current device otherwise (could be CPU).
compute_device = torch.device("cuda") if self.cpu_offload else p.device
# A single shard of the parameters in full precision.
p._fp32_shard = p.data
@@ -707,7 +778,7 @@
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=compute_device, dtype=self.compute_dtype)
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage_(p._fp16_shard)
else:
p._fp16_shard = None # use _fp32_shard
@@ -720,7 +791,7 @@
# relevant computation.
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=compute_device, dtype=self.compute_dtype
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
)
free_storage_(p._full_param_padded)
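The resize-to-zero trick described in the comments can be seen in isolation with plain PyTorch (not FSDP-specific):

import torch

t = torch.zeros(1024)
t.storage().resize_(0)     # free the backing memory; tensor metadata remains
assert t.storage().size() == 0
t.storage().resize_(1024)  # re-materialize before use, then copy data back in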
@@ -1290,20 +1361,6 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
def cast_buffers_(
module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
"""Cast all of module.named_buffers to device and floating point buffers to dtype."""
# if buffers are already on the right device and/or dtype this is just python loop cost
assert dtype in {torch.float32, torch.float16} # assumes compute_dtype == float16
for key, buf in module.named_buffers(recurse=False):
if buf is not None:
buf = buf.to(device=device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype)
setattr(module, key, buf)
def free_storage_(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
......
@@ -33,3 +33,4 @@ tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/data_parallel/test_fsdp_apply.py
@@ -77,6 +77,49 @@ class DistributedTest(unittest.TestCase):
model = FullyShardedDataParallel(TransformerWithSharedParams(group, **model_kwargs), group, **config).cuda()
return model
@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False
# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()
# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()
try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
class TestMixedPrecision(DistributedTest):
def test_all_fp32(self):
@@ -313,49 +356,6 @@ class TestComparisonToPyTorchDDP(DistributedTest):
def _dummy_ddp_fn(self, model, group):
return DummyDDP(model)
@classmethod
def _test_identical_outputs(
cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
):
if config.get("mixed_precision", False):
autocast = True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# this will cause the all-gather to happen in FP32, which is slower
# than necessary in most cases.
config["compute_dtype"] = torch.float32
else:
autocast = False
# Establish reference behavior with PyTorch DDP (+ optionally autocast).
model = model_init_fn(group=group, wrapper_config=None).cuda()
if ref_ddp_fn is None:
model = nn.parallel.DistributedDataParallel(
model, device_ids=[rank], output_device=rank, process_group=group
)
else:
model = ref_ddp_fn(model, group)
ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
ref_state_dict = model.module.state_dict()
if config.get("cpu_offload", False):
for k in ref_state_dict.keys():
ref_state_dict[k] = ref_state_dict[k].cpu()
# Confirm we get the same behavior using FullyShardedDataParallel.
model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
if use_cuda:
model = model.cuda()
else:
assert next(model.parameters()).device == torch.device("cpu")
shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
shard_state_dict = model.state_dict()
try:
torch.testing.assert_allclose(ref_loss, shard_loss)
assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
except (AssertionError, RuntimeError) as e:
raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
@parameterized.expand([[1], [inf]], name_func=rename_test)
def test_clip_norm_transformer(self, norm_type):
config = {"mixed_precision": True}
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import unittest
from parameterized import parameterized
import torch.nn as nn
from .test_fsdp import (
CONFIG_OPTIONS,
DistributedTest,
NestedWrappedModule,
TransformerWithSharedParams,
rename_test,
spawn_and_init,
)
class TestApply(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_weight_init(self, config):
model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, TransformerWithSharedParams)
test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
spawn_and_init(test_fn)
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_nested_wrapped_weight_init(self, config):
model_init_fn = functools.partial(model_init_and_apply_custom_weight_init, NestedWrappedModule)
test_fn = functools.partial(self._test_identical_outputs, model_init_fn, config, lr=0.01)
spawn_and_init(test_fn)
def model_init_and_apply_custom_weight_init(model_init_fn, *args, **kwargs):
model = model_init_fn(*args, **kwargs)
model.apply(init_bert_params_)
return model
def init_bert_params_(module):
"""
Initialize the weights specific to the BERT Model.
"""
def normal_(data):
# with FSDP, module params will be on CUDA, so we cast them back to CPU
# so that the RNG is consistent with and without FSDP
data.copy_(data.cpu().normal_(mean=0.0, std=0.02))
if isinstance(module, nn.Linear):
normal_(module.weight.data)
if module.bias is not None:
module.bias.data.zero_()
if isinstance(module, nn.Embedding):
normal_(module.weight.data)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if isinstance(module, nn.MultiheadAttention):
normal_(module.in_proj_weight.data)
if __name__ == "__main__":
unittest.main()
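The new test can be run with unittest directly, or (assuming the repository's usual pytest setup) via:

python -m pytest tests/nn/data_parallel/test_fsdp_apply.py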