Unverified commit a0458b98, authored by Min Xu and committed by GitHub

[fix] FSDP: disable single rank process group for auto_wrap_bn and fixed mixed precision regnet test (#556)

* [fix] disable single rank process group for auto_wrap_bn

- beefed up unit test with regnet-like model
- found that single-rank process group is causing problem
- disabled it to enable convergence tests on the vissl side
- use `raise e from None` to get a better assertion output
  in testing.py.

* [test] fix regnet test for ddp+mixed_precision

- need AMP context in FSDP
- work around a difference between DDP & FSDP when bias=True
- fixed a bug in input data generation that caused different ranks to have
  the same data and a wrong iteration count.
- added a TODO for a better loss and grad_scaler, and reduced
  iterations so there is no NaN.
- added (disabled) debugging code

* lint

* lint

* add scaler

* lint

* scaler

* add a real loss

* seeding in the ranks

* balance tests

* run AMP DDP==FSDP test only on cuda version 11 and up

* add relu inplace and comment

* make wrap_bn cover more cases in full precision mode
parent acb9ef00
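On the test side, the thrust of these fixes is that an FSDP model running with `mixed_precision=True` needs the same `torch.cuda.amp.autocast` context and a gradient scaler as the DDP+AMP baseline (plus `fp32_reduce_scatter=True` so gradient reduction matches DDP's FP32 reduction). Below is a condensed, hypothetical sketch of that training-step pattern; the `train_step` helper is illustrative only, not the actual test code in the diff that follows.

```python
import torch
from torch.cuda.amp import GradScaler, autocast  # GradScaler for DDP; FSDP uses ShardedGradScaler


def train_step(model, optim, scaler, in_data, loss_func, use_amp):
    # Run the forward pass under autocast for both DDP+AMP and FSDP(mixed_precision=True).
    with autocast(enabled=use_amp):
        out = model(in_data)
        loss = loss_func(out)
    # scaler may be GradScaler (DDP), fairscale's ShardedGradScaler (FSDP),
    # or a no-op stand-in when AMP is off; the call pattern stays the same.
    scaler.scale(loss).backward()
    scaler.step(optim)
    scaler.update()
    optim.zero_grad()
    return loss.detach()
```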
@@ -1517,7 +1517,7 @@ def _pre_load_state_dict_hook(
 ########################################################################################
-def auto_wrap_bn(module: nn.Module) -> nn.Module:
+def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
     """
     Auto wrap all BatchNorm (BN) instances with a safer FSDP, esp. when convert
     to sync BN is used and the outer FSDP is flattening.
@@ -1531,6 +1531,9 @@ def auto_wrap_bn(module: nn.Module) -> nn.Module:
     Args:
         module (nn.Module):
             The model (or part of the model) in which BN to be pre-wrapped.
+        single_rank_pg (bool):
+            If true, put BNs in a single-rank process group. Default False.
+            This might be needed for Apex sync BN support. Still under construction.
 
     Returns:
         Processed module, where BNs are wrapped with a special FSDP instance.
@@ -1543,10 +1546,15 @@ def auto_wrap_bn(module: nn.Module) -> nn.Module:
         else:
             return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES))  # type: ignore
 
-    my_rank = dist.get_rank()
+    pg = None
+    if single_rank_pg:
+        # No sharding with this single member group.
+        my_rank = dist.get_rank()
+        pg = dist.new_group(ranks=[my_rank])
+
     fsdp_config = {
         "wrapper_cls": FullyShardedDataParallel,
-        "process_group": dist.new_group(ranks=[my_rank]),  # No sharding with this single member group.
+        "process_group": pg,
         "mixed_precision": False,  # Keep the weights in FP32.
         "flatten_parameters": False,  # Do not flatten.
     }
...
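For reference, a minimal sketch of how the updated helper might be used. The `build_fsdp_model` wrapper is hypothetical; `auto_wrap_bn` and `FullyShardedDataParallel` are the fairscale APIs touched by this hunk, and the sketch assumes `torch.distributed` is already initialized.

```python
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import auto_wrap_bn


def build_fsdp_model(model: nn.Module) -> FSDP:
    # Pre-wrap BN layers so they stay in FP32 and are not flattened with the rest.
    # single_rank_pg now defaults to False; pass True only when experimenting with
    # Apex sync BN support (still under construction per the docstring above).
    model = auto_wrap_bn(model, single_rank_pg=False)
    # Convert to SyncBatchNorm after (or before) BN wrapping; the regnet test
    # below exercises both orders.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # Outer FSDP wrap; uses the default (already initialized) process group.
    return FSDP(model)
```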
@@ -32,6 +32,7 @@ import logging
 import multiprocessing
 import os
 import random
+import subprocess
 import sys
 import tempfile
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -116,6 +117,30 @@ def torch_version() -> Tuple[int, ...]:
     return tuple(int(n) for n in numbering)
 
 
+_smi_ver = None
+
+
+def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
+    if compiled:
+        numbering = torch.version.cuda.split(".")[:2]
+    else:
+        global _smi_ver
+        if _smi_ver is None:
+
+            def get_smi_ver() -> str:
+                """Get CUDA version from nvidia-smi"""
+                for line in subprocess.check_output("nvidia-smi".split()).decode("utf-8").split("\n"):
+                    if "CUDA Version" in line:
+                        res = line.split()[8]
+                        assert res.startswith("10.") or res.startswith("11."), res
+                        return res
+                assert False
+
+            _smi_ver = get_smi_ver()
+        numbering = _smi_ver.split(".")[:2]
+    return tuple(int(n) for n in numbering)
+
+
 def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
     """
     Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
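A small usage sketch for the new helper; the `maybe_skip_amp_test` function name is hypothetical, while `torch_cuda_version` is the helper added above (by default it parses the driver's CUDA runtime version from nvidia-smi; `compiled=True` returns the version PyTorch was built against).

```python
import pytest

from fairscale.utils.testing import torch_cuda_version


def maybe_skip_amp_test() -> None:
    # Gate AMP-equivalence tests on the runtime CUDA version reported by nvidia-smi.
    if torch_cuda_version() < (11, 0):
        pytest.skip("DDP==FSDP AMP equivalence is only checked on CUDA 11+")
```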
@@ -445,7 +470,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
                 # Add dict key to the assertion error.
                 msg = e.args[0]
                 new_msg = f"For dict key '{dict_key}': {msg}"
-                raise AssertionError(new_msg)
+                raise AssertionError(new_msg) from None
             else:
                 raise e
         else:
...
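The `from None` suffix is what the commit message refers to: it suppresses Python's implicit exception chaining, so the re-raised assertion shows only the enriched message instead of a "During handling of the above exception..." cascade. A standalone illustration with toy values (not fairscale code):

```python
def check(value: int, key: str) -> None:
    try:
        assert value == 0, f"unexpected value {value}"
    except AssertionError as e:
        # Re-raise with more context; "from None" hides the original traceback chain.
        raise AssertionError(f"For dict key '{key}': {e.args[0]}") from None


check(1, "weight")  # AssertionError: For dict key 'weight': unexpected value 1
```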
@@ -33,4 +33,8 @@ tests/nn/pipe/test_deferred_batch_norm.py
 tests/nn/pipe/test_dependency.py
 tests/nn/pipe/test_stream.py
 tests/experimental/nn/test_multiprocess_pipe.py
+tests/nn/moe/test_moe_layer.py
+tests/nn/moe/test_top2gating.py
+tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
+tests/experimental/nn/test_offload.py
 tests/nn/data_parallel/test_fsdp_apply.py
+tests/nn/data_parallel/test_fsdp_regnet.py
 tests/nn/data_parallel/test_fsdp_uneven.py
 tests/nn/data_parallel/test_fsdp_grad_scaler.py
 tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_fsdp_input.py
 tests/nn/data_parallel/test_fsdp_multiple_forward.py
-tests/nn/data_parallel/test_fsdp_regnet.py
 tests/nn/data_parallel/test_fsdp_optimizer_utils.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
@@ -17,12 +17,8 @@ tests/nn/pipe/skip/test_portal.py
 tests/nn/pipe/skip/test_tracker.py
 tests/nn/pipe/skip/test_inspect_skip_layout.py
 tests/nn/pipe/test_checkpoint_ddp.py
-tests/nn/moe/test_moe_layer.py
-tests/nn/moe/test_top2gating.py
 tests/optim/test_single_node_adascale.py
 tests/optim/test_adam.py
 tests/optim/test_oss.py
 tests/optim/test_oss_adascale.py
 tests/optim/test_ddp_adascale.py
-tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
-tests/experimental/nn/test_offload.py
@@ -15,13 +15,26 @@ import tempfile
 import pytest
 import torch
+from torch.cuda.amp import GradScaler
 import torch.multiprocessing as mp
-from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
+from torch.nn import (
+    AdaptiveAvgPool2d,
+    BatchNorm2d,
+    Conv2d,
+    CrossEntropyLoss,
+    Linear,
+    Module,
+    ReLU,
+    Sequential,
+    Sigmoid,
+    SyncBatchNorm,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
+from fairscale.optim.grad_scaler import ShardedGradScaler
 from fairscale.utils.testing import (
     dist_init,
     objects_are_equal,
@@ -29,21 +42,96 @@ from fairscale.utils.testing import (
     skip_if_single_gpu,
     state_dict_norm,
     teardown,
+    torch_cuda_version,
     torch_version,
 )
+
+# Const test params.
+# Reduce iterations to 1 for debugging.
+# Change world_size to 8 on beefy machines for better test coverage.
+_world_size = 2
+_iterations = 5
+
+# Cover different ReLU flavor. This will cause DDP and FSDP models to have
+# different ReLUs since they will different random flags.
+_relu_inplace = True
+if random.randint(0, 1) == 0:
+    _relu_inplace = False
+
+# TODO (Min): test apex BN when available in the future.
+try:
+    import apex
+
+    apex_bn_converter = apex.parallel.convert_syncbn_model
+except ImportError:
+    apex_bn_converter = None
+pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm  # type: ignore
+_bn_converter = pytorch_bn_converter
+_single_rank_pg = False
+
+
+class ResBlock(Module):
+    """Conv block in regnet with residual connection."""
+
+    def __init__(self, width_in, width_out):
+        super().__init__()
+        self.proj = Conv2d(width_in, width_out, (1, 1), (2, 2), bias=False)
+        self.bn = BatchNorm2d(width_out)
+        self.f = Sequential(
+            Sequential(  # block a
+                Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False), BatchNorm2d(width_out), ReLU(_relu_inplace),
+            ),
+            Sequential(  # block b
+                Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
+                BatchNorm2d(width_out),
+                ReLU(_relu_inplace),
+            ),
+            Sequential(  # block se
+                AdaptiveAvgPool2d((1, 1)),
+                Sequential(
+                    Conv2d(width_out, 2, (1, 1), (1, 1), bias=False),
+                    ReLU(_relu_inplace),
+                    Conv2d(2, width_out, (1, 1), (1, 1), bias=False),
+                    Sigmoid(),
+                ),
+            ),
+            Conv2d(width_out, width_out, (1, 1), (1, 1), bias=False),  # block c
+            BatchNorm2d(width_out),  # final_bn
+        )
+        self.relu = ReLU()
+        self.need_fsdp_wrap = True
+
+    def forward(self, x):
+        x = self.bn(self.proj(x)) + self.f(x)
+        return self.relu(x)
 class Model(Module):
+    """SSL model with trunk and head."""
+
     def __init__(self):
         super().__init__()
-        # TODO (Min): for now, we just test pytorch sync_bn here.
-        # this will grow into regnet; testing apex sync_bn, etc.
-        self.conv = Conv2d(2, 2, (1, 1))
-        self.bn = BatchNorm2d(2)
+        print(f"Using relu inplace: {_relu_inplace}")
+
+        self.trunk = Sequential()
+        self.trunk.need_fsdp_wrap = True  # Set a flag for later wrapping.
+        stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=False), BatchNorm2d(4), ReLU(_relu_inplace))
+        any_stage_block1_0 = ResBlock(4, 8)
+        self.trunk.add_module("stem", stem)
+        self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0))
+
+        self.head = Sequential(
+            # TODO (Min): FSDP-mixed_precision doesn't compute the same ways as DDP AMP when bias=True.
+            #             so, we use bias=False for now in the projection_head.
+            #             The Conv2d layers above does not use bias in regnet, but even if they use
+            #             bias, FSDP and DDP seem to agree on how it is computed.
+            Sequential(Linear(16, 16, bias=False), ReLU(), Linear(16, 8, bias=False),),  # projection_head
+            Linear(8, 15, bias=False),  # prototypes0
+        )
 
     def forward(self, x):
-        x = self.conv(x)
-        x = self.bn(x)
+        x = self.trunk(x).reshape(-1)
+        x = self.head(x)
         return x
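As a sanity check on the dimensions above (this snippet is not part of the test; it assumes the `ResBlock` and `Model` classes from this hunk are in scope, e.g. when run inside the test module):

```python
import torch

model = Model()
x = torch.rand(2, 2, 2, 2)  # same input shape the test generates per iteration
# trunk: stem (2, 2, 2, 2) -> (2, 4, 1, 1); any_stage_block1 -> (2, 8, 1, 1); reshape(-1) -> 16 features
# head: Linear(16, 16) -> ReLU -> Linear(16, 8) -> Linear(8, 15) -> 15 logits
out = model(x)
assert out.shape == (15,)
```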
@@ -67,9 +155,10 @@ def ddp_ref():
     state_before = model.state_dict()
 
     # Get reference inputs per rank.
-    world_size = 2
-    iterations = 100
-    inputs = [[]] * world_size
+    world_size = _world_size
+    iterations = _iterations
+    print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
+    inputs = [[] for i in range(world_size)]
     for rank in range(world_size):
         for i in range(iterations):
             inputs[rank].append(torch.rand(2, 2, 2, 2))
@@ -86,6 +175,7 @@ def ddp_ref():
             args=(
                 world_size,
                 fsdp_config,
+                None,
                 precision == "mixed",
                 temp_file_name,
                 unused,
@@ -128,6 +218,7 @@ def _test_func(
     rank,
     world_size,
     fsdp_config,
+    fsdp_wrap_bn,
     ddp_mixed_precision,
     tempfile_name,
     unused,
@@ -143,27 +234,51 @@ def _test_func(
     if fsdp_config:
         ddp = False
         assert isinstance(fsdp_config, dict), str(fsdp_config)
+        if fsdp_config["mixed_precision"]:
+            # To match DDP in AMP -O1, we need fp32 reduce scatter.
+            fsdp_config["fp32_reduce_scatter"] = True
 
     model = Model()
     model.load_state_dict(state_before)
     model = model.cuda()
+
+    class DummyScaler:
+        def scale(self, loss):
+            return loss
+
+        def step(self, optim):
+            optim.step()
+
+        def update(self):
+            pass
+
+    scaler = DummyScaler()
     if ddp:
         model = SyncBatchNorm.convert_sync_batchnorm(model)
-        model = DDP(model, device_ids=[rank])
+        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
+        if ddp_mixed_precision:
+            scaler = GradScaler()
     else:
         # Note, different rank may wrap in different order due to different random
         # seeds. But results should be the same.
        if random.randint(0, 1) == 0:
-            print("auto_wrap_bn, then convert_sync_batchnorm")
-            model = auto_wrap_bn(model)
-            model = SyncBatchNorm.convert_sync_batchnorm(model)
+            print(f"auto_wrap_bn {fsdp_wrap_bn}, then convert_sync_batchnorm")
+            if fsdp_wrap_bn:
+                model = auto_wrap_bn(model, _single_rank_pg)
+            model = _bn_converter(model)
         else:
-            print("convert_sync_batchnorm, then auto_wrap_bn")
-            model = SyncBatchNorm.convert_sync_batchnorm(model)
-            model = auto_wrap_bn(model)
+            print(f"convert_sync_batchnorm, then auto_wrap_bn {fsdp_wrap_bn}")
+            model = _bn_converter(model)
+            if fsdp_wrap_bn:
+                model = auto_wrap_bn(model, _single_rank_pg)
         model = FSDP(model, **fsdp_config).cuda()
+        if fsdp_config["mixed_precision"]:
+            scaler = ShardedGradScaler()
+
+    # Print the model for verification.
+    if rank == 0:
+        print(model)
+
     optim = SGD(model.parameters(), lr=0.1)
+    loss_func = CrossEntropyLoss()
 
     for in_data in inputs[rank]:
         in_data = in_data.cuda()
@@ -171,11 +286,15 @@ def _test_func(
         if ddp and ddp_mixed_precision:
             in_data = in_data.half()
             context = torch.cuda.amp.autocast(enabled=True)
+        if not ddp and fsdp_config["mixed_precision"]:
+            context = torch.cuda.amp.autocast(enabled=True)
         with context:
             out = model(in_data)
-            loss = out.sum()
-        loss.backward()
-        optim.step()
+            fake_label = torch.zeros(1, dtype=torch.long).cuda()
+            loss = loss_func(out.unsqueeze(0), fake_label)
+        scaler.scale(loss).backward()
+        scaler.step(optim)
+        scaler.update()
         optim.zero_grad()
 
         if ddp:
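The new loss call works because `CrossEntropyLoss` expects `(N, C)` logits and `(N,)` integer class labels; the model emits a flat 15-element prototype vector, so it is unsqueezed into a batch of one against a fixed label. A standalone check with toy tensors, independent of the test:

```python
import torch
from torch.nn import CrossEntropyLoss

loss_func = CrossEntropyLoss()
out = torch.randn(15)                           # flat logits, like the model output
fake_label = torch.zeros(1, dtype=torch.long)   # single target, class index 0
loss = loss_func(out.unsqueeze(0), fake_label)  # shapes: (1, 15) logits vs (1,) labels
assert loss.dim() == 0                          # scalar loss, ready for backward()
```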
@@ -190,6 +309,15 @@ def _test_func(
     # Move tensors to CPU to compare numerics.
     for k, v in fsdp_state.items():
         fsdp_state[k] = v.cpu()
+
+    # Change False to True to enable this when you want to debug the mismatch.
+    if False and rank == 0:
+
+        def dump(d):
+            for k, v in d.items():
+                print(k, v)
+
+        dump(state_after)
+        dump(fsdp_state)
     assert objects_are_equal(state_after, fsdp_state, raise_exception=True)
 
     teardown()
@@ -215,10 +343,32 @@ def test1(temp_files, ddp_ref, precision, flatten):
     fsdp_config["mixed_precision"] = precision == "mixed"
     fsdp_config["flatten_parameters"] = flatten == "flatten"
 
-    world_size = 2
+    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
+        pytest.skip("Only CUDA 11 is supported with AMP equivalency")
+
+    # Wrap BN half of the time in full precision mode.
+    wrap_bn = True
+    if random.randint(0, 1) == 0:
+        wrap_bn = False
+    # Always wrap BN in mixed precision mode.
+    if fsdp_config["mixed_precision"]:
+        wrap_bn = True
+
+    world_size = _world_size
     mp.spawn(
         _test_func,
-        args=(world_size, fsdp_config, None, temp_files[0], temp_files[1], state_before, inputs, None, state_after),
+        args=(
+            world_size,
+            fsdp_config,
+            wrap_bn,
+            None,
+            temp_files[0],
+            temp_files[1],
+            state_before,
+            inputs,
+            None,
+            state_after,
+        ),
         nprocs=world_size,
         join=True,
     )