Unverified Commit 5be4817d authored by anj-s, committed by GitHub

[refactor] [fsdp] Modify FSDP API param name to better reflect functionality (#676)

* api changes

* fix list

* modify changelog

* modify changelog

* modify changelog

* move function
parent bbac5564
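
For context, here is a minimal usage sketch of the renamed argument. This is not part of the commit; it assumes fairscale is installed, a single CUDA device is available, and uses an arbitrary rendezvous address, since FSDP needs an initialized process group before it can be constructed.

    # Illustrative only, not part of this commit.
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Single-process group so FSDP can be constructed for illustration.
    dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    # New spelling introduced by this change; only meaningful with mixed_precision=True.
    model = FSDP(nn.Linear(16, 16).cuda(), mixed_precision=True, move_params_to_cpu=True)

    # The old keyword is still accepted for now and folds into the new flag,
    # and the old attribute stays readable through the compatibility property.
    legacy = FSDP(nn.Linear(16, 16).cuda(), mixed_precision=True, cpu_offload=True)
    assert legacy.move_params_to_cpu and legacy.cpu_offload
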
@@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - MoE: several fixes [#666] [#667] [#668]
 - setup.py: hide CUDA extensions behind `BUILD_CUDA_EXTENSIONS` envvar [#634]
 - SDP: re-expose the module property [#647]
-- Cleanup - rename and move the `checkpoint_activations` wrapper [#654]
+- checkpointing: rename and move the checkpoint_activations wrapper [#654]
+- FSDP: Rename API arg `cpu_offload` to `move_params_to_cpu` to better reflect functionality. We will deprecate `cpu_offload` in an upcoming release. [#676]
 ### Added
 - FSDP: added `force_input_to_fp32` flag for SyncBatchNorm [#659]
...
@@ -146,7 +146,7 @@ class FullyShardedDataParallel(nn.Module):
         flatten_parameters (bool, Optional):
             if ``True``, flatten parameters into a single contiguous tensor,
             which improves training speed.
-        cpu_offload (bool, Optional):
+        move_params_to_cpu (bool, Optional):
             if ``True``, offload FP32 params to CPU. This is only relevant when
             *``mixed_precision``* is ``True``.
         compute_dtype (torch.dtype, Optional):
@@ -215,6 +215,10 @@ class FullyShardedDataParallel(nn.Module):
         verbose (bool):
             Set this to ``True`` to turn on verbose output for model's string representation.
             Default: False
+        cpu_offload (bool, Optional):
+            if ``True``, offload FP32 params to CPU. This is only relevant when
+            *``mixed_precision``* is ``True``. Note: This arg will be deprecated in favor of
+            *``move_params_to_cpu``* in an upcoming release.
     """

     def __init__(
@@ -225,7 +229,7 @@ class FullyShardedDataParallel(nn.Module):
         mixed_precision: bool = False,
         fp32_reduce_scatter: bool = False,
         flatten_parameters: bool = True,
-        cpu_offload: bool = False,
+        move_params_to_cpu: bool = False,
         compute_dtype: Optional[torch.dtype] = None,
         buffer_dtype: Optional[torch.dtype] = None,
         move_grads_to_cpu: Optional[bool] = None,
@@ -236,6 +240,7 @@ class FullyShardedDataParallel(nn.Module):
         clear_autocast_cache: bool = False,
         force_input_to_fp32: bool = False,
         verbose: bool = False,
+        cpu_offload: bool = False,
     ):
         init_start = time.time()
         super().__init__()
@@ -246,10 +251,10 @@ class FullyShardedDataParallel(nn.Module):
         self.mixed_precision = mixed_precision
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self.flatten_parameters = flatten_parameters
-        self.cpu_offload = cpu_offload
+        self.move_params_to_cpu = move_params_to_cpu or cpu_offload
         self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
         self.buffer_dtype = buffer_dtype or self.compute_dtype
-        self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
+        self.move_grads_to_cpu = self.move_params_to_cpu if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
         self.compute_device = compute_device or _get_default_cuda_device(module)
         self.uncollected_opt_state: Dict[int, Dict] = {}
@@ -267,7 +272,7 @@ class FullyShardedDataParallel(nn.Module):
         if self.fp32_reduce_scatter and not self.mixed_precision:
             raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
-        if self.cpu_offload and not self.mixed_precision:
+        if self.move_params_to_cpu and not self.mixed_precision:
             raise ValueError("cpu_offload requires mixed_precision=True")
         # skip validation if the process group was created above
@@ -570,7 +575,7 @@ class FullyShardedDataParallel(nn.Module):
             f"buffer_dtype={self.buffer_dtype}, "
             f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
             f"compute_device={self.compute_device}"
-            f"cpu_offload={self.cpu_offload}, "
+            f"cpu_offload={self.move_params_to_cpu}, "
             f"move_grads_to_cpu={self.move_grads_to_cpu}, "
             f"bucket_cap_mb={self.bucket_cap_mb}, "
             f"clear_autocast_cache={self.clear_autocast_cache}"
@@ -658,7 +663,7 @@ class FullyShardedDataParallel(nn.Module):
         else:
             state_dict = super().state_dict(*args, **kwargs)
-        if self.cpu_offload:
+        if self.move_params_to_cpu:
             for k in state_dict.keys():
                 state_dict[k] = state_dict[k].cpu()
@@ -909,7 +914,7 @@ class FullyShardedDataParallel(nn.Module):
         if self.mixed_precision:
             assert p._fp32_shard.dtype == torch.float32
-        if self.cpu_offload:
+        if self.move_params_to_cpu:
             assert p._fp32_shard.device == torch.device("cpu")
             # If we plan to keep the FP32 parameters on CPU, then pinning
             # memory allows us to later use non-blocking transfers when moving
@@ -1429,7 +1434,7 @@ class FullyShardedDataParallel(nn.Module):
             if not p._is_sharded:  # e.g., when world_size == 1
                 update_p_data()
             else:
-                # If self.cpu_offload and force_full_precision, we need to cast
+                # If self.move_params_to_cpu and force_full_precision, we need to cast
                 # the FP32 CPU param to CUDA for the all-gather.
                 p_data = p.data.to(p._full_param_padded.device)
@@ -1717,6 +1722,11 @@ class FullyShardedDataParallel(nn.Module):
             f"{msg} cur={torch.cuda.memory_allocated()/gb_denom: .4f} GB, max={torch.cuda.max_memory_allocated()/gb_denom: .4f} GB, t={time.time()-self._tstart: .1f}"
         )

+    # Note: This property will be deprecated in an upcoming release in favor of `move_params_to_cpu`.
+    @property
+    def cpu_offload(self) -> bool:
+        return self.move_params_to_cpu

 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
...
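
The backward-compatibility pattern in the diff above, shown in isolation as a generic sketch: the deprecated keyword is OR-ed into the new flag at construction time, and a read-only property keeps the old attribute name readable. The `Offloader` class and the warning are illustrative only, not fairscale APIs; the commit itself only announces the deprecation.

    import warnings


    class Offloader:
        """Generic sketch of the rename-with-compatibility-shim pattern used above."""

        def __init__(self, move_params_to_cpu: bool = False, cpu_offload: bool = False):
            if cpu_offload:
                # The diff only announces the deprecation; a warning like this is
                # one way it could eventually be surfaced to callers.
                warnings.warn("cpu_offload is deprecated; use move_params_to_cpu", DeprecationWarning)
            # Either spelling enables the behavior, mirroring
            # `self.move_params_to_cpu = move_params_to_cpu or cpu_offload` in the diff.
            self.move_params_to_cpu = move_params_to_cpu or cpu_offload

        @property
        def cpu_offload(self) -> bool:
            # Old attribute name stays readable for existing call sites.
            return self.move_params_to_cpu


    assert Offloader(cpu_offload=True).cpu_offload
    assert Offloader(move_params_to_cpu=True).cpu_offload
    assert not Offloader().move_params_to_cpu
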