Unverified Commit 861b5ce2 authored by Min Xu, committed by GitHub

[fix] add clear_autocast_cache flag (#650)



* [fix] add clear_autocast_cache flag

- when training in AMP mode with weights kept in FP32, FSDP may need to
  optionally clear the autocast cache to avoid GPU OOM
- this flag defaults to False; clearing the cache automatically is a future TODO
- also added a verbose flag so that print(fsdp_model) is shorter by default
- updated the memory test to cover the new code
- added a couple of useful functions in parallel.py and testing.py

* minor

* address comments

* format

* improve the test
Co-authored-by: Min Xu <min.xu@acm.org>
parent 14d1f78c
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
 ### Fixed
+- FSDP: workaround AMP autocast cache issue with clear\_autocast\_cache flag
 - setup.py: hide CUDA extensions behind BUILD_CUDA_EXTENSIONS envvar
 - SDP: re-expose the module property ([#647](https://github.com/facebookresearch/fairscale/pull/647))
...
@@ -195,6 +195,15 @@ class FullyShardedDataParallel(nn.Module):
             device for parameters returned by :func:`state_dict`. If not given,
             this will default to ``compute_dtype``. Note that only the device
             type will be respected (e.g., "cuda:0" and "cuda:1" are the same).
+        clear_autocast_cache (bool):
+            When using mixed precision training with ``torch.cuda.amp.autocast``, if the model
+            weights are in FP32, autocast maintains a cache of the downcast copies of those
+            weights. The cache can cause GPU OOM during the forward pass. Setting this flag
+            to ``True`` clears the cache as inner FSDP instances finish their part of the
+            forward pass, which saves GPU memory.
+            Default: False
+        verbose (bool):
+            Set this to ``True`` to turn on verbose output for the model's string representation.
+            Default: False
     """

     def __init__(
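For orientation, here is a minimal sketch of how the two new flags might be used together; the module, sizes, and call site are illustrative and not part of this diff:

    import torch
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Assumes torch.distributed has already been initialized
    # (e.g., via torch.distributed.init_process_group), which FSDP requires.
    # The FP32-weights-under-autocast case (mixed_precision=True together with
    # compute_dtype=torch.float32) is where the autocast weight cache can grow large.
    model = FSDP(
        nn.Linear(1024, 1024).cuda(),
        mixed_precision=True,
        compute_dtype=torch.float32,
        clear_autocast_cache=True,  # clear autocast's cached downcast weight copies after forward
        verbose=True,               # include the full flag list in repr(model)
    )

    with torch.cuda.amp.autocast(enabled=True):
        out = model(torch.randn(8, 1024).cuda())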
@@ -213,6 +222,8 @@ class FullyShardedDataParallel(nn.Module):
         compute_device: Optional[torch.device] = None,
         no_broadcast_optim_state: Optional[bool] = False,
         state_dict_device: Optional[torch.device] = None,
+        clear_autocast_cache: bool = False,
+        verbose: bool = False,
     ):
         init_start = time.time()
         super().__init__()
@@ -232,6 +243,8 @@ class FullyShardedDataParallel(nn.Module):
         self.uncollected_opt_state: Dict[int, Dict] = {}
         self.no_broadcast_optim_state = no_broadcast_optim_state
         self.state_dict_device = state_dict_device or self.compute_device
+        self.clear_autocast_cache = clear_autocast_cache
+        self.verbose = verbose
         self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
@@ -248,6 +261,7 @@ class FullyShardedDataParallel(nn.Module):
         if process_group:
             validate_process_group(self.compute_device, self.process_group)

+        # Enable PyTorch sync_bn just in case the model contains sync_bn layers.
         enable_pytorch_sync_bn(module)

         # Only handle params which are not already sharded. This enables
@@ -301,7 +315,7 @@ class FullyShardedDataParallel(nn.Module):
             f"FSDP.__init__(done): total_init_time: {(init_end - init_start): .4f} num_params: {(sum(p.numel() for p in self.params))}"
         )

-        # Flag to guard multiple pre-forward hook being executed per iteration.
+        # Flag to guard against multiple pre-backward hooks being executed per iteration.
         # This is reset at the end of the backward pass.
         self._pre_backward_hook_has_run = False
@@ -531,19 +545,24 @@ class FullyShardedDataParallel(nn.Module):
         return shard, num_to_pad

     def extra_repr(self) -> str:
-        return (
-            f"rank={self.rank}, world_size={self.world_size}, "
-            f"reshard_after_forward={self.reshard_after_forward}, "
-            f"mixed_precision={self.mixed_precision}, "
-            f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
-            f"flatten_parameters={self.flatten_parameters}, "
-            f"cpu_offload={self.cpu_offload}, "
-            f"compute_dtype={self.compute_dtype}, "
-            f"buffer_dtype={self.buffer_dtype}, "
-            f"move_grads_to_cpu={self.move_grads_to_cpu}, "
-            f"bucket_cap_mb={self.bucket_cap_mb}, "
-            f"compute_device={self.compute_device}"
-        )
+        repr = (
+            f"world_size={self.world_size}, "
+            f"flatten_parameters={self.flatten_parameters}, "
+            f"mixed_precision={self.mixed_precision}, "
+        )
+        if self.verbose:
+            repr = (
+                f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
+                f"compute_dtype={self.compute_dtype}, "
+                f"buffer_dtype={self.buffer_dtype}, "
+                f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
+                f"compute_device={self.compute_device}, "
+                f"cpu_offload={self.cpu_offload}, "
+                f"move_grads_to_cpu={self.move_grads_to_cpu}, "
+                f"bucket_cap_mb={self.bucket_cap_mb}, "
+                f"clear_autocast_cache={self.clear_autocast_cache}"
+            )
+        return repr

     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
@@ -1001,6 +1020,12 @@ class FullyShardedDataParallel(nn.Module):
         # Done with a forward pass.
         self.training_state = TrainingState.IDLE

+        # Only need to clear the cache during forward. During backward, the cache is not used.
+        # TODO (Min): Future PyTorch versions may provide a way to completely disable this
+        #             cache. Update this when that's available.
+        if self.clear_autocast_cache:
+            torch.clear_autocast_cache()
+
         return outputs

     def _register_pre_backward_hooks(self, outputs: Any) -> Any:
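For background, here is a standalone sketch of the behavior being worked around (not part of the diff): under autocast, every FP32 weight that participates in a lower-precision op gets a downcast copy that autocast caches until the cache is cleared, so a model with many FSDP-wrapped submodules can accumulate a large cache within a single forward pass. The layer stack and sizes below are illustrative and assume a CUDA device:

    import torch
    import torch.nn as nn

    # Illustrative only: FP32 linear layers run under autocast.
    layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)]).cuda()
    x = torch.randn(4, 1024, device="cuda")

    with torch.cuda.amp.autocast():
        for layer in layers:
            x = layer(x)
            # Without this call, autocast keeps a downcast copy of every FP32 weight it
            # has seen so far in the region; clearing frees those copies early, which is
            # what FSDP does above when clear_autocast_cache=True.
            torch.clear_autocast_cache()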
@@ -1454,7 +1479,7 @@ class FullyShardedDataParallel(nn.Module):
         current_stream = torch.cuda.current_stream()
         for p in params:
             if p._fp16_shard is not None:
-                # _fp16_shard is allocated in _fp32_to_fp16_stream, so we can't
+                # _fp16_shard is allocated in the "fp32_to_fp16" stream, so we can't
                 # free it until the work in the current stream completes.
                 p._fp16_shard.record_stream(current_stream)
                 free_storage_(p._fp16_shard)
...
@@ -56,3 +56,20 @@ def enable_pytorch_sync_bn(module: torch.nn.Module) -> None:
             # used, but this call needs to be made to avoid an exception.
             # This function is removed from pytorch since 1.9.
             layer._specify_ddp_gpu_num(1)  # type: ignore
+
+
+def get_global_group() -> None:
+    """
+    Singleton PyTorch distributed group.
+
+    Inspired by https://github.com/pytorch/fairseq
+
+    For FSDP, it is important to use a global group; otherwise, inner FSDP instances
+    will not share the gradient reduction bucket buffers with the root instance and
+    will end up using more GPU memory.
+    """
+    if dist.is_initialized():
+        if not hasattr(get_global_group, "_global_group"):
+            get_global_group._global_group = dist.new_group()  # type: ignore
+        return get_global_group._global_group  # type: ignore
+    else:
+        return None
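A brief usage sketch (assuming torch.distributed has already been initialized elsewhere): every FSDP wrapper, inner and root, is handed the same cached group so they share the reduction bucket buffers. The wrap helper below is hypothetical, not part of the diff:

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.utils.parallel import get_global_group

    def wrap(module: nn.Module) -> FSDP:
        # Reusing the singleton group keeps inner and root FSDP instances on the same
        # process group, so gradient reduction bucket buffers are shared between them.
        return FSDP(module, process_group=get_global_group())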
@@ -28,6 +28,7 @@ relative imports.
 import contextlib
 import functools
+import gc
 import inspect
 import logging
 import multiprocessing

@@ -666,3 +667,17 @@ def temp_files_ctx(num: int) -> Generator:
         # temp files could have been removed, so we use rmf.
         for name in files:
             rmf(name)
+
+
+def dump_all_tensors(rank: int) -> None:
+    """Useful tool for debugging memory issues from the python side."""
+    if rank != 0:
+        return
+    for obj in gc.get_objects():
+        try:
+            ttype = str(type(obj))
+            if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
+                print(ttype, obj.shape, obj.dtype, obj.device, obj.storage().size())
+        except Exception:
+            pass
+    print(torch.cuda.memory_summary())
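A quick usage sketch (the call site and step function are hypothetical, and a CUDA device is assumed): drop the helper at a point where memory looks suspiciously high to list live tensors and the CUDA allocator summary on rank 0.

    from fairscale.utils.testing import dump_all_tensors

    def training_step(rank, model, batch):
        loss = model(batch).sum()
        loss.backward()
        # Debugging hook: on rank 0, print every live tensor (type, shape, dtype,
        # device, storage size) followed by torch.cuda.memory_summary().
        dump_all_tensors(rank)
        return loss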
@@ -1920,6 +1920,7 @@ def set_default_dtype(d : _dtype) -> None: ...
 def manager_path() -> str: ...
 def compiled_with_cxx11_abi() -> _bool: ...
 def is_autocast_enabled() -> _bool: ...
+def clear_autocast_cache() -> None: ...

 # The return value of this function depends on the value of `as_tuple`,
 # (similar to `unique`, `lu`, etc.); as such, it is not
...
@@ -45,6 +45,7 @@ def reset_peak_memory_stats(device: Union[_device_t, int] = None) -> None: ...
 def memory_cached(device: Optional[_device_t]=...) -> int: ...
 def max_memory_cached(device: Optional[_device_t]=...) -> int: ...
 def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ...
+def memory_summary() -> str: ...
 def cudart() -> ctypes.CDLL: ...
 def find_cuda_windows_lib() -> Optional[ctypes.CDLL]: ...
 #MODIFIED BY TORCHGPIPE
...
@@ -9,11 +9,10 @@
 """ Test FSDP with GPU memory usage. """

-import gc
+import contextlib

 import pytest
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel

@@ -22,38 +21,19 @@ import torch.optim as optim
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
-
-
-def get_global_group():
-    """
-    Singleton pytorch distributed group
-    Inspired by https://github.com/pytorch/fairseq
-    """
-    if dist.is_initialized():
-        if not hasattr(get_global_group, "_global_group"):
-            get_global_group._global_group = dist.new_group()
-        return get_global_group._global_group
-    else:
-        return None
-
-
-def to_fsdp(module):
-    return FSDP(module, process_group=get_global_group())
-
-
-def dump_all_tensors(rank):
-    """Use this for debugging"""
-    if rank != 0:
-        return
-    for obj in gc.get_objects():
-        try:
-            # Only need to check parameter type objects if asked.
-            ttype = str(type(obj))
-            if torch.is_tensor(obj) or (hasattr(obj, "data") and torch.is_tensor(obj.data)):
-                print(ttype, obj.shape, obj.dtype, obj.device, id(obj), obj.storage().size())
-        except Exception as e:
-            pass
+from fairscale.utils.parallel import get_global_group
+from fairscale.utils.testing import (
+    dist_init,
+    dump_all_tensors,
+    skip_if_single_gpu,
+    teardown,
+    temp_files_ctx,
+    torch_version,
+)
+
+
+def to_fsdp(module, fsdp_config):
+    return FSDP(module, process_group=get_global_group(), **fsdp_config)


 def get_cur_mem(rank, result, prefix):
@@ -62,45 +42,51 @@ def get_cur_mem(rank, result, prefix):

 class Model(nn.Module):
-    def __init__(self):
+    def __init__(self, hidden_dim):
         super().__init__()
+        # TODO (Min): for both fast and memory efficient conv kernels, we should be using
+        #             AMP/fp16 + channel_last input format. Otherwise, cudnn internally does
+        #             the conversion to channel_last when the weights are fp16. Leave this
+        #             knowledge here; perhaps a future test can cover it.
         self.stem = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
         self.blocks = nn.Sequential(
-            nn.Conv2d(64, 128, kernel_size=5, padding=2),
-            nn.BatchNorm2d(128),
+            nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm2d(hidden_dim),
             nn.ReLU(inplace=True),
-            nn.Conv2d(128, 128, kernel_size=5, padding=2),
-            nn.BatchNorm2d(128),
+            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm2d(hidden_dim),
             nn.ReLU(inplace=True),
-            nn.Conv2d(128, 128, kernel_size=5, padding=2),
-            nn.BatchNorm2d(128),
+            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm2d(hidden_dim),
             nn.ReLU(inplace=True),
             nn.AdaptiveAvgPool2d(output_size=(1, 1)),
             nn.Flatten(),
         )
-        self.head = nn.Linear(128, 10)
+        self.head = nn.Linear(hidden_dim, 10)

     def forward(self, x):
         return self.head(self.blocks(self.stem(x)))


-def create_model(with_fsdp, with_checkpoint):
-    model = Model()
+def create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config):
+    model = Model(model_hidden_dim)
     if with_fsdp:
         model.stem = auto_wrap_bn(model.stem, single_rank_pg=False)
         model.blocks = auto_wrap_bn(model.blocks, single_rank_pg=False)
         if with_checkpoint:
             model.blocks = checkpoint_wrapper(model.blocks)
-        model.stem = to_fsdp(model.stem)
-        model.blocks = to_fsdp(model.blocks)
-        model.head = to_fsdp(model.head)
+        model.stem = to_fsdp(model.stem, fsdp_config)
+        model.blocks = to_fsdp(model.blocks, fsdp_config)
+        model.head = to_fsdp(model.head, fsdp_config)
     else:
         if with_checkpoint:
             model.blocks = checkpoint_wrapper(model.blocks)
     return model


-def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected):
+def _distributed_worker(
+    gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected, model_hidden_dim, fsdp_config
+):
     torch.cuda.set_device(gpu_id)

     rank = gpu_id

@@ -109,28 +95,42 @@ def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename
     torch.manual_seed(0)
     torch.backends.cudnn.deterministic = True

+    # Note that FSDP auto-casts the input in AMP mode, so we don't need to call half() here.
     batch = torch.randn(size=(2, 3, 224, 224)).cuda()

-    model = create_model(with_fsdp, with_checkpoint)
+    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config)
     model = model.cuda()
     if with_fsdp:
-        model = to_fsdp(model)
+        model = to_fsdp(model, fsdp_config)
     else:
         model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

+    # We enable momentum so that after the first iteration, the optimizer state is added
+    # to the total memory used.
     criterion = nn.MSELoss()
-    optimizer = optim.SGD(model.parameters(), lr=1e-4)
+    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

+    # Set the AMP context if needed.
+    context = contextlib.suppress()
+    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
+        context = torch.cuda.amp.autocast(enabled=True)
+
+    # We have observed that sometimes after the 3rd iteration, the 4th one can fail (not in
+    # this test but in much bigger scale tests). We run 4 iterations here just in case.
+    iterations = 4
+
-    results = {}
-    for iteration in range(3):
+    results = {}  # results of memory stats
+    for iteration in range(iterations):
         get_cur_mem(gpu_id, results, f"iter {iteration}: start")

-        out = model(batch)
-        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")
-
-        out = sum(o.sum() for o in out[0])
-        fake_loss = criterion(out, torch.tensor(0.0).cuda())
-        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")
+        with context:
+            out = model(batch)
+            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")
+
+            out = sum(o.sum() for o in out[0])
+            fake_loss = criterion(out, torch.tensor(0.0).cuda())
+            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

         fake_loss.backward()
         get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")
@@ -146,14 +146,27 @@ def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename
             p.grad = None
         get_cur_mem(gpu_id, results, f"iter {iteration}: done")

-    assert results == expected, f"{results} but expected {expected}"
+    dump_all_tensors(gpu_id)
+    print(results)
+
+    def cmp(results, expected):
+        ret = ""
+        assert results.keys() == expected.keys(), f"{list(results.keys())} vs. {list(expected.keys())}"
+        for k, v in results.items():
+            exp = expected[k]
+            if abs(exp - v) > 1:  # allow 1MB rounding differences
+                ret += f"{k}: got {v}, expected {exp}\n"
+        return ret
+
+    output = cmp(results, expected)
+    assert not output, output

     teardown()


 @skip_if_single_gpu
 @pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"])
-@pytest.mark.parametrize("fsdp", ["ddp", "fsdp"])
+@pytest.mark.parametrize("fsdp", ["ddp", "fsdp", "fsdp_amp_default", "fsdp_amp_compute_dtype32"])
 def test_fsdp_memory(fsdp, ckpt):
     expected = {
         ("ddp", "no_ckpt"): {
@@ -161,86 +174,244 @@ def test_fsdp_memory(fsdp, ckpt):
             "iter 0: after fwd": 346,
             "iter 0: after loss": 346,
             "iter 0: after bwd": 14,
-            "iter 0: after step": 14,
-            "iter 0: done": 9,
-            "iter 1: start": 9,
-            "iter 1: after fwd": 346,
-            "iter 1: after loss": 346,
-            "iter 1: after bwd": 14,
-            "iter 1: after step": 14,
-            "iter 1: done": 9,
-            "iter 2: start": 9,
-            "iter 2: after fwd": 346,
-            "iter 2: after loss": 346,
-            "iter 2: after bwd": 14,
-            "iter 2: after step": 14,
-            "iter 2: done": 9,
+            "iter 0: after step": 17,
+            "iter 0: done": 13,
+            "iter 1: start": 13,
+            "iter 1: after fwd": 350,
+            "iter 1: after loss": 350,
+            "iter 1: after bwd": 17,
+            "iter 1: after step": 17,
+            "iter 1: done": 13,
+            "iter 2: start": 13,
+            "iter 2: after fwd": 350,
+            "iter 2: after loss": 350,
+            "iter 2: after bwd": 17,
+            "iter 2: after step": 17,
+            "iter 2: done": 13,
+            "iter 3: start": 13,
+            "iter 3: after fwd": 350,
+            "iter 3: after loss": 350,
+            "iter 3: after bwd": 17,
+            "iter 3: after step": 17,
+            "iter 3: done": 13,
         },
         ("fsdp", "no_ckpt"): {
             "iter 0: start": 3,
             "iter 0: after fwd": 340,
             "iter 0: after loss": 340,
             "iter 0: after bwd": 66,
-            "iter 0: after step": 66,
-            "iter 0: done": 3,
-            "iter 1: start": 3,
-            "iter 1: after fwd": 340,
-            "iter 1: after loss": 340,
-            "iter 1: after bwd": 66,
-            "iter 1: after step": 66,
-            "iter 1: done": 3,
-            "iter 2: start": 3,
-            "iter 2: after fwd": 340,
-            "iter 2: after loss": 340,
-            "iter 2: after bwd": 66,
-            "iter 2: after step": 66,
-            "iter 2: done": 3,
+            "iter 0: after step": 68,
+            "iter 0: done": 5,
+            "iter 1: start": 5,
+            "iter 1: after fwd": 342,
+            "iter 1: after loss": 342,
+            "iter 1: after bwd": 68,
+            "iter 1: after step": 68,
+            "iter 1: done": 5,
+            "iter 2: start": 5,
+            "iter 2: after fwd": 342,
+            "iter 2: after loss": 342,
+            "iter 2: after bwd": 68,
+            "iter 2: after step": 68,
+            "iter 2: done": 5,
+            "iter 3: start": 5,
+            "iter 3: after fwd": 342,
+            "iter 3: after loss": 342,
+            "iter 3: after bwd": 68,
+            "iter 3: after step": 68,
+            "iter 3: done": 5,
         },
+        ("fsdp_amp_default", "no_ckpt"): {
+            "iter 0: start": 28,
+            "iter 0: after fwd": 630,
+            "iter 0: after loss": 630,
+            "iter 0: after bwd": 104,
+            "iter 0: after step": 131,
+            "iter 0: done": 54,
+            "iter 1: start": 54,
+            "iter 1: after fwd": 657,
+            "iter 1: after loss": 657,
+            "iter 1: after bwd": 131,
+            "iter 1: after step": 131,
+            "iter 1: done": 54,
+            "iter 2: start": 54,
+            "iter 2: after fwd": 657,
+            "iter 2: after loss": 657,
+            "iter 2: after bwd": 131,
+            "iter 2: after step": 131,
+            "iter 2: done": 54,
+            "iter 3: start": 54,
+            "iter 3: after fwd": 657,
+            "iter 3: after loss": 657,
+            "iter 3: after bwd": 131,
+            "iter 3: after step": 131,
+            "iter 3: done": 54,
+        },
+        ("fsdp_amp_compute_dtype32", "no_ckpt"): {
+            "iter 0: start": 28,
+            "iter 0: after fwd": 657,
+            "iter 0: after loss": 657,
+            "iter 0: after bwd": 117,
+            "iter 0: after step": 143,
+            "iter 0: done": 54,
+            "iter 1: start": 54,
+            "iter 1: after fwd": 684,
+            "iter 1: after loss": 684,
+            "iter 1: after bwd": 143,
+            "iter 1: after step": 143,
+            "iter 1: done": 54,
+            "iter 2: start": 54,
+            "iter 2: after fwd": 684,
+            "iter 2: after loss": 684,
+            "iter 2: after bwd": 143,
+            "iter 2: after step": 143,
+            "iter 2: done": 54,
+            "iter 3: start": 54,
+            "iter 3: after fwd": 684,
+            "iter 3: after loss": 684,
+            "iter 3: after bwd": 143,
+            "iter 3: after step": 143,
+            "iter 3: done": 54,
+        },
         ("ddp", "ckpt"): {
             "iter 0: start": 9,
             "iter 0: after fwd": 57,
             "iter 0: after loss": 57,
             "iter 0: after bwd": 14,
-            "iter 0: after step": 14,
-            "iter 0: done": 9,
-            "iter 1: start": 9,
-            "iter 1: after fwd": 57,
-            "iter 1: after loss": 57,
-            "iter 1: after bwd": 14,
-            "iter 1: after step": 14,
-            "iter 1: done": 9,
-            "iter 2: start": 9,
-            "iter 2: after fwd": 57,
-            "iter 2: after loss": 57,
-            "iter 2: after bwd": 14,
-            "iter 2: after step": 14,
-            "iter 2: done": 9,
+            "iter 0: after step": 17,
+            "iter 0: done": 13,
+            "iter 1: start": 13,
+            "iter 1: after fwd": 61,
+            "iter 1: after loss": 61,
+            "iter 1: after bwd": 17,
+            "iter 1: after step": 17,
+            "iter 1: done": 13,
+            "iter 2: start": 13,
+            "iter 2: after fwd": 61,
+            "iter 2: after loss": 61,
+            "iter 2: after bwd": 17,
+            "iter 2: after step": 17,
+            "iter 2: done": 13,
+            "iter 3: start": 13,
+            "iter 3: after fwd": 61,
+            "iter 3: after loss": 61,
+            "iter 3: after bwd": 17,
+            "iter 3: after step": 17,
+            "iter 3: done": 13,
         },
         ("fsdp", "ckpt"): {
             "iter 0: start": 3,
             "iter 0: after fwd": 51,
             "iter 0: after loss": 51,
             "iter 0: after bwd": 66,
-            "iter 0: after step": 66,
-            "iter 0: done": 3,
-            "iter 1: start": 3,
-            "iter 1: after fwd": 51,
-            "iter 1: after loss": 51,
-            "iter 1: after bwd": 66,
-            "iter 1: after step": 66,
-            "iter 1: done": 3,
-            "iter 2: start": 3,
-            "iter 2: after fwd": 51,
-            "iter 2: after loss": 51,
-            "iter 2: after bwd": 66,
-            "iter 2: after step": 66,
-            "iter 2: done": 3,
+            "iter 0: after step": 68,
+            "iter 0: done": 5,
+            "iter 1: start": 5,
+            "iter 1: after fwd": 53,
+            "iter 1: after loss": 53,
+            "iter 1: after bwd": 68,
+            "iter 1: after step": 68,
+            "iter 1: done": 5,
+            "iter 2: start": 5,
+            "iter 2: after fwd": 53,
+            "iter 2: after loss": 53,
+            "iter 2: after bwd": 68,
+            "iter 2: after step": 68,
+            "iter 2: done": 5,
+            "iter 3: start": 5,
+            "iter 3: after fwd": 53,
+            "iter 3: after loss": 53,
+            "iter 3: after bwd": 68,
+            "iter 3: after step": 68,
+            "iter 3: done": 5,
         },
+        ("fsdp_amp_default", "ckpt"): {
+            "iter 0: start": 28,
+            "iter 0: after fwd": 52,
+            "iter 0: after loss": 52,
+            "iter 0: after bwd": 104,
+            "iter 0: after step": 131,
+            "iter 0: done": 54,
+            "iter 1: start": 54,
+            "iter 1: after fwd": 79,
+            "iter 1: after loss": 79,
+            "iter 1: after bwd": 131,
+            "iter 1: after step": 131,
+            "iter 1: done": 54,
+            "iter 2: start": 54,
+            "iter 2: after fwd": 79,
+            "iter 2: after loss": 79,
+            "iter 2: after bwd": 131,
+            "iter 2: after step": 131,
+            "iter 2: done": 54,
+            "iter 3: start": 54,
+            "iter 3: after fwd": 79,
+            "iter 3: after loss": 79,
+            "iter 3: after bwd": 131,
+            "iter 3: after step": 131,
+            "iter 3: done": 54,
+        },
+        ("fsdp_amp_compute_dtype32", "ckpt"): {
+            "iter 0: start": 28,
+            "iter 0: after fwd": 52,
+            "iter 0: after loss": 52,
+            "iter 0: after bwd": 117,
+            "iter 0: after step": 143,
+            "iter 0: done": 54,
+            "iter 1: start": 54,
+            "iter 1: after fwd": 79,
+            "iter 1: after loss": 79,
+            "iter 1: after bwd": 143,
+            "iter 1: after step": 143,
+            "iter 1: done": 54,
+            "iter 2: start": 54,
+            "iter 2: after fwd": 79,
+            "iter 2: after loss": 79,
+            "iter 2: after bwd": 143,
+            "iter 2: after step": 143,
+            "iter 2: done": 54,
+            "iter 3: start": 54,
+            "iter 3: after fwd": 79,
+            "iter 3: after loss": 79,
+            "iter 3: after bwd": 143,
+            "iter 3: after step": 143,
+            "iter 3: done": 54,
+        },
     }[(fsdp, ckpt)]

-    fsdp = fsdp == "fsdp"
-    ckpt = ckpt == "ckpt"
+    # Compute the FSDP config.
+    fsdp_config = {}
+
+    # Set mixed precision.
+    if "amp" in fsdp:
+        fsdp_config["mixed_precision"] = True
+
+    # When compute_dtype is FP32, make sure we use clear_autocast_cache.
+    # Setting fp32_reduce_scatter and verbose is for more code coverage.
+    if "compute_dtype32" in fsdp:
+        fsdp_config["compute_dtype"] = torch.float32
+        fsdp_config["fp32_reduce_scatter"] = True
+        fsdp_config["clear_autocast_cache"] = True
+        fsdp_config["verbose"] = True
+
+    # Use a bigger hidden dimension for AMP to increase the model size so that bugs in
+    # param handling show up, but don't do that in the base case to keep the test fast.
+    # - hidden_dim 128: model size ~4MB
+    # - hidden_dim 512: model size ~55MB
+    # - hidden_dim 1024: model size ~200MB (seems to be too big for CI tests though)
+    model_hidden_dim = 128
+    if "amp" in fsdp:
+        model_hidden_dim = 512
+
+    # Get the fsdp and checkpoint flags.
+    with_fsdp = "fsdp" in fsdp
+    with_ckpt = ckpt == "ckpt"

     world_size = 2
     with temp_files_ctx(num=2) as temp_files:
         mp.spawn(
-            _distributed_worker, (world_size, fsdp, ckpt, temp_files[0], temp_files[1], expected), nprocs=world_size
+            _distributed_worker,
+            (world_size, with_fsdp, with_ckpt, temp_files[0], temp_files[1], expected, model_hidden_dim, fsdp_config),
+            nprocs=world_size,
         )