Unverified commit 5be4817d authored by anj-s, committed by GitHub

[refactor] [fsdp] Modify FSDP API param name to better reflect functionality (#676)

* api changes

* fix list

* modify changelog

* modify changelog

* modify changelog

* move function
parent bbac5564
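
For context, here is a minimal usage sketch of the renamed argument. The import path, the toy model, and the single-process group setup are assumptions made for illustration; only the constructor arguments themselves are taken from the diff below.

```python
# Sketch only: shows the renamed keyword argument in context. The import path,
# toy model, and single-process group are assumptions, not part of this commit.
import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# FSDP needs an initialized process group even for a single process.
if not dist.is_initialized():
    dist.init_process_group(
        backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

model = nn.Linear(1024, 1024).cuda()

# New spelling: move_params_to_cpu replaces cpu_offload. It requires
# mixed_precision=True, since only the FP32 copy of the params is offloaded.
fsdp_model = FSDP(model, mixed_precision=True, move_params_to_cpu=True)

# The old spelling keeps working for now (it is aliased internally, see the
# diff below), but is slated for deprecation:
# fsdp_model = FSDP(model, mixed_precision=True, cpu_offload=True)
```
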
@@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- MoE: several fixes [#666] [#667] [#668]
- setup.py: hide CUDA extensions behind `BUILD_CUDA_EXTENSIONS` envvar [#634]
- SDP: re-expose the module property [#647]
- Cleanup - rename and move the `checkpoint_activations` wrapper [#654]
- checkpointing: rename and move the checkpoint_activations wrapper [#654]
- FSDP: Rename API arg `cpu_offload` to `move_params_to_cpu` to better reflect functionality. We will deprecate `cpu_offload` in an upcoming release. [#676]
### Added
- FSDP: added `force_input_to_fp32` flag for SyncBatchNorm [#659]
@@ -146,7 +146,7 @@ class FullyShardedDataParallel(nn.Module):
flatten_parameters (bool, Optional):
if ``True``, flatten parameters into a single contiguous tensor,
which improves training speed.
cpu_offload (bool, Optional):
move_params_to_cpu (bool, Optional):
if ``True``, offload FP32 params to CPU. This is only relevant when
*``mixed_precision``* is ``True``.
compute_dtype (torch.dtype, Optional):
@@ -215,6 +215,10 @@ class FullyShardedDataParallel(nn.Module):
verbose (bool):
Set this to ``True`` to turn on verbose output for model's string representation.
Default: False
cpu_offload (bool, Optional):
if ``True``, offload FP32 params to CPU. This is only relevant when
*``mixed_precision``* is ``True``. Note: This arg will be deprecated in favor of
*``move_params_to_cpu``* in an upcoming release.
"""
def __init__(
@@ -225,7 +229,7 @@ class FullyShardedDataParallel(nn.Module):
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
flatten_parameters: bool = True,
cpu_offload: bool = False,
move_params_to_cpu: bool = False,
compute_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
move_grads_to_cpu: Optional[bool] = None,
@@ -236,6 +240,7 @@ class FullyShardedDataParallel(nn.Module):
clear_autocast_cache: bool = False,
force_input_to_fp32: bool = False,
verbose: bool = False,
cpu_offload: bool = False,
):
init_start = time.time()
super().__init__()
@@ -246,10 +251,10 @@ class FullyShardedDataParallel(nn.Module):
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self.flatten_parameters = flatten_parameters
self.cpu_offload = cpu_offload
self.move_params_to_cpu = move_params_to_cpu or cpu_offload
self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.move_grads_to_cpu = self.move_params_to_cpu if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb
self.compute_device = compute_device or _get_default_cuda_device(module)
self.uncollected_opt_state: Dict[int, Dict] = {}
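
The hunk above is the core of the backward-compatibility story: either keyword enables the behaviour, the new attribute becomes the source of truth, and dependent defaults such as `move_grads_to_cpu` are derived from it. Below is a stripped-down sketch of the same pattern on a hypothetical class, just to make the precedence explicit; it is not the fairscale code.

```python
# Hypothetical, minimal reproduction of the kwarg-renaming pattern above.
from typing import Optional


class Wrapper:
    def __init__(
        self,
        move_params_to_cpu: bool = False,
        move_grads_to_cpu: Optional[bool] = None,
        cpu_offload: bool = False,  # old name, kept for backward compatibility
    ):
        # Either spelling turns the feature on; the new name is the attribute.
        self.move_params_to_cpu = move_params_to_cpu or cpu_offload
        # Dependent defaults derive from the new attribute, not the old kwarg.
        self.move_grads_to_cpu = (
            self.move_params_to_cpu if move_grads_to_cpu is None else move_grads_to_cpu
        )

    @property
    def cpu_offload(self) -> bool:
        # Reading the old attribute name keeps working.
        return self.move_params_to_cpu


assert Wrapper(cpu_offload=True).move_params_to_cpu
assert Wrapper(move_params_to_cpu=True).cpu_offload
assert Wrapper(cpu_offload=True).move_grads_to_cpu
```
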
@@ -267,7 +272,7 @@ class FullyShardedDataParallel(nn.Module):
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.cpu_offload and not self.mixed_precision:
if self.move_params_to_cpu and not self.mixed_precision:
raise ValueError("cpu_offload requires mixed_precision=True")
# skip validation if the process group was created above
@@ -570,7 +575,7 @@ class FullyShardedDataParallel(nn.Module):
f"buffer_dtype={self.buffer_dtype}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"compute_device={self.compute_device}"
f"cpu_offload={self.cpu_offload}, "
f"cpu_offload={self.move_params_to_cpu}, "
f"move_grads_to_cpu={self.move_grads_to_cpu}, "
f"bucket_cap_mb={self.bucket_cap_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
@@ -658,7 +663,7 @@ class FullyShardedDataParallel(nn.Module):
else:
state_dict = super().state_dict(*args, **kwargs)
if self.cpu_offload:
if self.move_params_to_cpu:
for k in state_dict.keys():
state_dict[k] = state_dict[k].cpu()
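
A practical consequence of the branch above: when `move_params_to_cpu` is set, the state dict returned on this code path lives on CPU and can be saved directly. A hedged check, continuing the usage sketch from the top of the page (the checkpoint path is illustrative):

```python
# Continuing the earlier sketch: with move_params_to_cpu=True, the state dict
# returned by this code path is already on CPU, so it can be saved as-is.
sd = fsdp_model.state_dict()
assert all(t.device == torch.device("cpu") for t in sd.values())
torch.save(sd, "/tmp/fsdp_checkpoint.pt")
```
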
@@ -909,7 +914,7 @@ class FullyShardedDataParallel(nn.Module):
if self.mixed_precision:
assert p._fp32_shard.dtype == torch.float32
if self.cpu_offload:
if self.move_params_to_cpu:
assert p._fp32_shard.device == torch.device("cpu")
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
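
The pinned-memory comment above refers to standard CUDA behaviour rather than anything FSDP-specific: page-locked host memory is what makes non-blocking host-to-device copies possible. A generic illustration, unrelated to FSDP internals:

```python
# Generic PyTorch illustration (not FSDP code) of the pinned-memory comment:
# a page-locked CPU tensor can be copied to the GPU with non_blocking=True,
# letting the transfer overlap with other GPU work.
import torch

if torch.cuda.is_available():
    cpu_shard = torch.randn(1024, 1024, dtype=torch.float32).pin_memory()
    gpu_param = cpu_shard.to("cuda", non_blocking=True)  # async because pinned
    torch.cuda.synchronize()  # wait before using gpu_param from the host side
```
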
@@ -1429,7 +1434,7 @@ class FullyShardedDataParallel(nn.Module):
if not p._is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# If self.cpu_offload and force_full_precision, we need to cast
# If self.move_params_to_cpu and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p._full_param_padded.device)
@@ -1717,6 +1722,11 @@ class FullyShardedDataParallel(nn.Module):
f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
)
# Note: This property will be deprecated in an upcoming release in favor of `move_params_to_cpu`.
@property
def cpu_offload(self) -> bool:
return self.move_params_to_cpu
def _get_default_cuda_device(module: nn.Module) -> torch.device:
"""Try to infer CUDA device from module parameters."""
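
Finally, the read-only property added at the end of the diff means existing code that reads the old attribute keeps working; only writes would break. A short check under the same assumptions as the first sketch:

```python
# Backward compatibility added by this commit: reading the old attribute name
# still works because it forwards to move_params_to_cpu (same FSDP instance
# and assumptions as in the first sketch above).
assert fsdp_model.move_params_to_cpu is True
assert fsdp_model.cpu_offload is True  # old name, served by the new property

# The alias is read-only: the property has no setter, so assigning to it
# would raise AttributeError.
# fsdp_model.cpu_offload = False
```
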