Unverified Commit ed7ca766 authored by anj-s, committed by GitHub

[fix] Decouple `move_params_to_cpu` from `mixed_precision`. (#822)

* remove offload dependency on fp16

* update python version for cpu tests

* run CPU tests with updated PyTorch version

* split changes

* revert tests config

* fix lint errors

* update nightly and test PyTorch versions

* skip failing multiprocess pipe test

* always skip test

* always skip test

* always skip test

* lint error

* skip unsupported versions

* improve skip message

* lint errors

* modify docs

* add tests

* fix test failures

* modify comments

* fix lint errors

* fix lint errors
parent b60f3db0
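The user-visible effect of this change: `move_params_to_cpu=True` is now accepted without `mixed_precision=True`. A minimal usage sketch — the single-process group and toy module below are illustrative assumptions, not part of this commit:

```python
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Illustrative single-process group; real runs would use torchrun/mp.spawn.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

# Before this commit, this combination raised
# ValueError("move_params_to_cpu requires mixed_precision=True").
model = FSDP(
    nn.Linear(8, 8),
    mixed_precision=False,    # full-precision compute...
    move_params_to_cpu=True,  # ...with sharded FP32 params kept on CPU
)
```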
@@ -180,8 +180,7 @@ class FullyShardedDataParallel(nn.Module):
             if ``True``, flatten parameters into a single contiguous tensor,
             which improves training speed.
         move_params_to_cpu (bool, Optional):
-            if ``True``, offload FP32 params to CPU. This is only relevant when
-            *``mixed_precision``* is ``True``.
+            if ``True``, offload params to CPU.
         compute_dtype (torch.dtype, Optional):
             dtype for full parameters for computation. This defaults to
             ``torch.float32`` unless *``mixed_precision``* is set, in which case
@@ -249,10 +248,8 @@ class FullyShardedDataParallel(nn.Module):
             Set this to ``True`` to turn on verbose output for model's string representation.
             Default: False
         cpu_offload (bool, Optional):
-            if ``True``, offload FP32 params to CPU. This is only relevant when
-            *``mixed_precision``* is ``True``. Note: This arg will be deprecated in favor of
-            *``move_params_to_cpu``* in an upcoming release. Please prefer
-            specifying ``move_params_to_cpu`` instead.
+            if ``True``, offload params to CPU. Note: This arg will be deprecated in favor of
+            *``move_params_to_cpu``* in an upcoming release.
     """

     def __init__(
@@ -306,8 +303,6 @@ class FullyShardedDataParallel(nn.Module):
         if self.fp32_reduce_scatter and not self.mixed_precision:
             raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
-        if self.move_params_to_cpu and not self.mixed_precision:
-            raise ValueError("move_params_to_cpu requires mixed_precision=True")

         # skip validation if the process group was created above
         if process_group:
@@ -415,7 +410,7 @@ class FullyShardedDataParallel(nn.Module):
     @property
     def module(self) -> FlattenParamsWrapper:
-        """ make model.module accessible, just like DDP. """
+        """make model.module accessible, just like DDP."""
         assert isinstance(self._fsdp_wrapped_module, FlattenParamsWrapper)
         return self._fsdp_wrapped_module
@@ -991,8 +986,9 @@ class FullyShardedDataParallel(nn.Module):
             (typically FP32, but this is dependent on the dtype of the model
             as it's passed in by the user). This can be on CPU or GPU
             depending on the value of *``move_params_to_cpu``*.
-        ``_fp16_shard``: if *``mixed_precision``* is ``True``, this will be
-            a single shard of the parameters in FP16, used for all-gather.
+        ``_fp16_shard``: a single shard of the parameters used for all-gather.
+            Its dtype can be FP16 or FP32 depending on the value of
+            *``compute_dtype``* and whether params are offloaded to CPU.
         ``_full_param_padded``: the full weight (padded to be evenly
             divisible by ``world_size``), used for computation in the
             forward and backward pass. This will be resized in place and
@@ -1007,24 +1003,33 @@ class FullyShardedDataParallel(nn.Module):
         if self.mixed_precision:
             assert p._fp32_shard.dtype == torch.float32

-            if self.move_params_to_cpu:
-                assert p._fp32_shard.device == torch.device("cpu")
-                # If we plan to keep the FP32 parameters on CPU, then pinning
-                # memory allows us to later use non-blocking transfers when moving
-                # the FP32 param shard to compute_device.
-                p._fp32_shard = p._fp32_shard.pin_memory()
-            p.data = p._fp32_shard
-
-            # In mixed precision mode, we maintain a reduced precision
-            # (typically FP16) parameter shard on compute_device for performing
-            # the computation in the forward/backward pass. We resize the
-            # storage to size 0 at init (here) and re-materialize (by copying
-            # from _fp32_shard) as needed.
-            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
-            free_storage_(p._fp16_shard)
-        else:
-            p._fp16_shard = None  # use _fp32_shard
+        if self.move_params_to_cpu:
+            assert p._fp32_shard.device == torch.device("cpu")
+            # If we plan to keep the FP32 parameters on CPU, then pinning
+            # memory allows us to later use non-blocking transfers when moving
+            # the FP32 param shard to compute_device.
+            p._fp32_shard = p._fp32_shard.pin_memory()
+        p.data = p._fp32_shard
+
+        if self.move_params_to_cpu or self.mixed_precision:
+            # In mixed precision mode, we maintain a reduced precision
+            # (typically FP16) parameter shard on compute_device for performing
+            # the computation in the forward/backward pass. We resize the
+            # storage to size 0 at init (here) and re-materialize (by copying
+            # from _fp32_shard) as needed. If offloading params to CPU, the
+            # dtype of the fp16 shard will depend on the *``compute_dtype``*.
+            p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
+            free_storage_(p._fp16_shard)
+
+        if self.mixed_precision:
+            assert p._fp32_shard.dtype == torch.float32
+
+        if not self.mixed_precision and not self.move_params_to_cpu:
+            # Use _fp32_shard if we are not using mixed precision or
+            # offloading params and grads to CPU.
+            p._fp16_shard = None

         # We also maintain a full-sized parameter of type self.compute_dtype
         # (FP16 for mixed_precision or FP32 otherwise). We resize the
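The pin-then-rematerialize pattern that this hunk generalizes can be reproduced with plain PyTorch. A sketch under stated assumptions: a CUDA device is present, and `storage().resize_(...)` stands in for internal helpers like the `free_storage_` seen above:

```python
import torch

# A CPU shard of parameters, pinned so host-to-device copies can be
# issued without blocking the host (the _fp32_shard pattern above).
fp32_shard = torch.randn(1024).pin_memory()

if torch.cuda.is_available():
    # Compute-precision buffer on the GPU, freed at init and
    # re-materialized on demand (the _fp16_shard pattern above).
    fp16_shard = torch.zeros_like(fp32_shard, device="cuda", dtype=torch.float16)
    fp16_shard.storage().resize_(0)  # like free_storage_()

    # Later: re-allocate, then copy asynchronously from pinned memory.
    fp16_shard.storage().resize_(fp32_shard.numel())
    fp16_shard.copy_(fp32_shard, non_blocking=True)
```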
@@ -1125,7 +1130,7 @@ class FullyShardedDataParallel(nn.Module):
         """
         if not torch.cuda.is_available():
             return
-        if self.mixed_precision:
+        if self.mixed_precision or self.move_params_to_cpu:
             self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
         else:
             self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
@@ -1159,7 +1164,7 @@ class FullyShardedDataParallel(nn.Module):
         if self.reshard_after_forward:
             self._free_full_params()
-        if self.mixed_precision:
+        if self.mixed_precision or self.move_params_to_cpu:
             self._free_fp16_param_shard()

         # Switch to main FP32 param shard. We maintain this invariant throughout
@@ -1605,7 +1610,7 @@ class FullyShardedDataParallel(nn.Module):
                 p.data = custom_output_tensor
                 output_tensors.append((p.data, True))
             elif not p._is_sharded:
-                if self.mixed_precision and not force_full_precision:
+                if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                     assert p._fp16_shard is not None
                     p.data = p._fp16_shard
                     output_tensors.append((p.data, True))
@@ -1627,9 +1632,20 @@ class FullyShardedDataParallel(nn.Module):
         self.has_full_params = True

         with torch.cuda.stream(self._streams["all_gather"]):
-            if self.mixed_precision and not force_full_precision:
+            if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                 self._cast_fp32_param_shards_to_fp16()

+            if self.move_params_to_cpu:
+                if force_full_precision:
+                    # If the compute_dtype and storage dtype are the same,
+                    # use pinned memory. Otherwise move p.data to the compute
+                    # device.
+                    if self.params[0].dtype == self.compute_dtype:
+                        self._cast_fp32_param_shards_to_fp16()
+                    else:
+                        for p in self.params:
+                            p.data = p.data.to(self.compute_device)

             for p in self.params:
                 if not p._is_sharded:  # e.g., when world_size == 1
                     update_p_data()
@@ -1661,8 +1677,12 @@ class FullyShardedDataParallel(nn.Module):
                     # Set p.data = output_tensor (with padding trimmed)
                     update_p_data(output_tensor)

-                    if self.mixed_precision and not force_full_precision:
+                    if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
+                        self._free_fp16_param_shard([p])
+
+                    if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
                         self._free_fp16_param_shard([p])

         torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors
@@ -1679,7 +1699,7 @@ class FullyShardedDataParallel(nn.Module):
         assert self.has_full_params
         for p in self.params:
             if not p._is_sharded:
-                if self.mixed_precision:
+                if self.mixed_precision or self.move_params_to_cpu:
                     assert p._fp16_shard is not None
                     assert p._fp16_shard.storage().size() != 0
                     p.data = p._fp16_shard
@@ -1689,8 +1709,8 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _prep_grads_for_backward(self) -> None:
-        """ Make sure p.grad is correctly prepared for the backward with
+        """Make sure p.grad is correctly prepared for the backward with
         right shape, device, accumulated values, etc.
         """
         for p in self.params:
             if p.grad is not None:
@@ -1718,7 +1738,7 @@ class FullyShardedDataParallel(nn.Module):
         current_stream = torch.cuda.current_stream()
         for p in params:
             if not p._is_sharded:  # e.g., world_size == 1
-                if self.mixed_precision:
+                if self.mixed_precision or self.move_params_to_cpu:
                     self._free_fp16_param_shard([p])
                 continue
             # Don't let PyTorch reuse this memory until all work in the current
@@ -2167,7 +2187,7 @@ def _pre_load_state_dict_hook(

 def _clean_path(path: str) -> str:
-    """ Remove FSDP related wrapper modules from a given state dict key str path. """
+    """Remove FSDP related wrapper modules from a given state dict key str path."""
     return ".".join([split for split in path.split(".") if split not in {"_fsdp_wrapped_module", "_fpw_module"}])
...
@@ -65,6 +65,9 @@ class DistributedTest(unittest.TestCase):
             model.clip_grad_norm_(clip_norm, norm_type)
         else:
             torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
+        params = [p for p in model.parameters()]
+        print(f"params.device {params[0].device} param.grad.device {params[0].grad.device}")
         optim.step()
         if isinstance(model, FullyShardedDataParallel):
             model.assert_state(TrainingState.IDLE)
@@ -302,6 +305,15 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         )
         spawn_and_init(test_fn)

+    def test_cpu_offload_and_cpu_grads_no_mixed_precision(self):
+        # We don't test the False condition because that requires the optimizer to internally do
+        # the device transfer and PyTorch optimizers don't support this.
+        config = {"mixed_precision": False, "cpu_offload": True, "move_grads_to_cpu": True}
+        test_fn = functools.partial(
+            self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
+        )
+        spawn_and_init(test_fn)
+
     def test_cpu_offload_and_cuda_grads_breaks(self):
         # If grads are on gpu, but model and optimizer are on cpu, backward breaks.
         config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": False}
@@ -403,14 +415,14 @@ class TestParamInit(DistributedTest):

 class TestSerialization(DistributedTest):
-    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
+    @parameterized.expand([[False, False], [True, False], [True, True], [False, True]], name_func=rename_test)
     def test_pickle(self, mixed_precision, cpu_offload):
         """Ensure that wrapped modules can be pickled/unpickled."""
         config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload}
         test_fn = functools.partial(self._test_pickle, config=config)
         spawn_and_init(test_fn, world_sizes=[2])

-    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
+    @parameterized.expand([[False, False], [True, False], [True, True], [False, True]], name_func=rename_test)
     def test_multiprocessing(self, mixed_precision, cpu_offload):
         """Ensure that wrapped modules can be sent via multiprocessing."""
         config = {"mixed_precision": mixed_precision, "cpu_offload": cpu_offload}
...
@@ -186,7 +186,7 @@ class TestStateDictDeviceDtype(DistributedTest):
         )
         spawn_and_init(test_fn)

-    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
+    @parameterized.expand([[False, False], [True, False], [True, True], [False, True]], name_func=rename_test)
     def test_state_dict_device_cuda(self, mixed_precision, cpu_offload):
         test_fn = functools.partial(
             self._test_state_dict_device,
@@ -194,7 +194,7 @@ class TestStateDictDeviceDtype(DistributedTest):
         )
         spawn_and_init(test_fn)

-    @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
+    @parameterized.expand([[False, False], [True, False], [True, True], [False, True]], name_func=rename_test)
     def test_state_dict_device_cpu(self, mixed_precision, cpu_offload):
         test_fn = functools.partial(
             self._test_state_dict_device,
...