Unverified commit a0458b98, authored by Min Xu and committed by GitHub

[fix] FSDP: disable single rank process group for auto_wrap_bn and fix the mixed precision regnet test (#556)

* [fix] disable single rank process group for auto_wrap_bn

- beefed up the unit test with a regnet-like model
- found that the single-rank process group was causing problems
- disabled it to enable convergence tests on the vissl side
- use `raise ... from None` in testing.py to get a better assertion output
  (a minimal sketch follows these notes)

* [test] fix regnet test for ddp+mixed_precision

- need AMP context in FSDP
- work around a difference between DDP & FSDP when bias=True
- fixed a bug in input data generation that caused different ranks to get
  the same data and a wrong iteration count
- added a TODO about needing a better loss and grad_scaler, and reduced
  iterations so there is no NaN
- added (disabled) debugging code

* lint

* lint

* add scaler

* lint

* scaler

* add a real loss

* seeding in the ranks

* balance tests

* run the AMP DDP==FSDP test only on CUDA version 11 and up

* add relu inplace and comment

* make wrap_bn cover more cases in full precision mode
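
A minimal sketch of the `raise ... from None` pattern used in testing.py (the tensors and the 'weight' key below are illustrative, not the real call site):

    import torch

    a, b = torch.zeros(2), torch.ones(2)  # deliberately unequal tensors
    try:
        assert torch.equal(a, b), "tensor mismatch"
    except AssertionError as e:
        new_msg = f"For dict key 'weight': {e.args[0]}"
        # "from None" suppresses the chained "During handling of the above
        # exception, another exception occurred" traceback, so pytest shows
        # only the re-raised, more informative message.
        raise AssertionError(new_msg) from None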
parent acb9ef00
......@@ -1517,7 +1517,7 @@ def _pre_load_state_dict_hook(
########################################################################################
def auto_wrap_bn(module: nn.Module) -> nn.Module:
def auto_wrap_bn(module: nn.Module, single_rank_pg: bool = False) -> nn.Module:
"""
Auto wrap all BatchNorm (BN) instances with a safer FSDP, especially when conversion
to sync BN is used and the outer FSDP is flattening.
......@@ -1531,6 +1531,9 @@ def auto_wrap_bn(module: nn.Module) -> nn.Module:
Args:
module (nn.Module):
The model (or part of the model) in which BN to be pre-wrapped.
single_rank_pg (bool):
If true, put BNs in a single-rank process group. Default False.
This might be needed for Apex sync BN support. Still under construction.
Returns:
Processed module, where BNs are wrapped with a special FSDP instance.
......@@ -1543,10 +1546,15 @@ def auto_wrap_bn(module: nn.Module) -> nn.Module:
else:
return is_bn and not isinstance(module, tuple(default_auto_wrap_policy.EXCLUDE_WRAP_MODULES)) # type: ignore
my_rank = dist.get_rank()
pg = None
if single_rank_pg:
# No sharding with this single member group.
my_rank = dist.get_rank()
pg = dist.new_group(ranks=[my_rank])
fsdp_config = {
"wrapper_cls": FullyShardedDataParallel,
"process_group": dist.new_group(ranks=[my_rank]), # No sharding with this single member group.
"process_group": pg,
"mixed_precision": False, # Keep the weights in FP32.
"flatten_parameters": False, # Do not flatten.
}
......
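
A minimal usage sketch of the updated helper (assumes torch.distributed is already initialized; MyModel is a hypothetical module containing BatchNorm layers; the wrapping order mirrors the regnet test below):

    from torch.nn import SyncBatchNorm
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.data_parallel import auto_wrap_bn

    model = MyModel()  # hypothetical model with BN layers
    # Pre-wrap BNs in a no-flatten, full-precision FSDP. single_rank_pg stays
    # False (the default) since the single-rank process group path is disabled.
    model = auto_wrap_bn(model, single_rank_pg=False)
    model = SyncBatchNorm.convert_sync_batchnorm(model)
    model = FSDP(model, mixed_precision=True, flatten_parameters=True).cuda()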
......@@ -32,6 +32,7 @@ import logging
import multiprocessing
import os
import random
import subprocess
import sys
import tempfile
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
......@@ -116,6 +117,30 @@ def torch_version() -> Tuple[int, ...]:
return tuple(int(n) for n in numbering)
_smi_ver = None
def torch_cuda_version(compiled: bool = False) -> Tuple[int, ...]:
if compiled:
numbering = torch.version.cuda.split(".")[:2]
else:
global _smi_ver
if _smi_ver is None:
def get_smi_ver() -> str:
"""Get CUDA version from nvidia-smi"""
for line in subprocess.check_output("nvidia-smi".split()).decode("utf-8").split("\n"):
if "CUDA Version" in line:
res = line.split()[8]
assert res.startswith("10.") or res.startswith("11."), res
return res
assert False
_smi_ver = get_smi_ver()
numbering = _smi_ver.split(".")[:2]
return tuple(int(n) for n in numbering)
def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
"""
Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
......@@ -445,7 +470,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
# Add dict key to the assertion error.
msg = e.args[0]
new_msg = f"For dict key '{dict_key}': {msg}"
raise AssertionError(new_msg)
raise AssertionError(new_msg) from None
else:
raise e
else:
......
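
A brief usage sketch of the new torch_cuda_version helper (mirrors how the regnet test below gates the AMP parity check; the version tuples in comments are examples):

    import pytest
    from fairscale.utils.testing import torch_cuda_version

    compiled = torch_cuda_version(compiled=True)  # CUDA version PyTorch was built with, e.g. (11, 1)
    runtime = torch_cuda_version()                # CUDA version reported by nvidia-smi, e.g. (11, 2)

    # Inside a test, skip the AMP parity check on older CUDA (as test1 below does):
    if runtime < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")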
......@@ -33,4 +33,8 @@ tests/nn/pipe/test_deferred_batch_norm.py
tests/nn/pipe/test_dependency.py
tests/nn/pipe/test_stream.py
tests/experimental/nn/test_multiprocess_pipe.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_no_sync.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
......@@ -17,12 +17,8 @@ tests/nn/pipe/skip/test_portal.py
tests/nn/pipe/skip/test_tracker.py
tests/nn/pipe/skip/test_inspect_skip_layout.py
tests/nn/pipe/test_checkpoint_ddp.py
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_top2gating.py
tests/optim/test_single_node_adascale.py
tests/optim/test_adam.py
tests/optim/test_oss.py
tests/optim/test_oss_adascale.py
tests/optim/test_ddp_adascale.py
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
tests/experimental/nn/test_offload.py
......@@ -15,13 +15,26 @@ import tempfile
import pytest
import torch
from torch.cuda.amp import GradScaler
import torch.multiprocessing as mp
from torch.nn import BatchNorm2d, Conv2d, Module, SyncBatchNorm
from torch.nn import (
AdaptiveAvgPool2d,
BatchNorm2d,
Conv2d,
CrossEntropyLoss,
Linear,
Module,
ReLU,
Sequential,
Sigmoid,
SyncBatchNorm,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import (
dist_init,
objects_are_equal,
......@@ -29,21 +42,96 @@ from fairscale.utils.testing import (
skip_if_single_gpu,
state_dict_norm,
teardown,
torch_cuda_version,
torch_version,
)
# Const test params.
# Reduce iterations to 1 for debugging.
# Change world_size to 8 on beefy machines for better test coverage.
_world_size = 2
_iterations = 5
# Cover different ReLU flavors. This will cause the DDP and FSDP models to have
# different ReLUs since they will get different random flags.
_relu_inplace = True
if random.randint(0, 1) == 0:
_relu_inplace = False
# TODO (Min): test apex BN when available in the future.
try:
import apex
apex_bn_converter = apex.parallel.convert_syncbn_model
except ImportError:
apex_bn_converter = None
pytorch_bn_converter = SyncBatchNorm.convert_sync_batchnorm # type: ignore
_bn_converter = pytorch_bn_converter
_single_rank_pg = False
class ResBlock(Module):
"""Conv block in regnet with residual connection."""
def __init__(self, width_in, width_out):
super().__init__()
self.proj = Conv2d(width_in, width_out, (1, 1), (2, 2), bias=False)
self.bn = BatchNorm2d(width_out)
self.f = Sequential(
Sequential( # block a
Conv2d(width_in, width_out, (1, 1), (1, 1), bias=False), BatchNorm2d(width_out), ReLU(_relu_inplace),
),
Sequential( # block b
Conv2d(width_out, width_out, (3, 3), (2, 2), (1, 1), groups=2, bias=False),
BatchNorm2d(width_out),
ReLU(_relu_inplace),
),
Sequential( # block se
AdaptiveAvgPool2d((1, 1)),
Sequential(
Conv2d(width_out, 2, (1, 1), (1, 1), bias=False),
ReLU(_relu_inplace),
Conv2d(2, width_out, (1, 1), (1, 1), bias=False),
Sigmoid(),
),
),
Conv2d(width_out, width_out, (1, 1), (1, 1), bias=False), # block c
BatchNorm2d(width_out), # final_bn
)
self.relu = ReLU()
self.need_fsdp_wrap = True
def forward(self, x):
x = self.bn(self.proj(x)) + self.f(x)
return self.relu(x)
class Model(Module):
"""SSL model with trunk and head."""
def __init__(self):
super().__init__()
# TODO (Min): for now, we just test pytorch sync_bn here.
# this will grow into regnet; testing apex sync_bn, etc.
self.conv = Conv2d(2, 2, (1, 1))
self.bn = BatchNorm2d(2)
print(f"Using relu inplace: {_relu_inplace}")
self.trunk = Sequential()
self.trunk.need_fsdp_wrap = True # Set a flag for later wrapping.
stem = Sequential(Conv2d(2, 4, (3, 3), (2, 2), (1, 1), bias=False), BatchNorm2d(4), ReLU(_relu_inplace))
any_stage_block1_0 = ResBlock(4, 8)
self.trunk.add_module("stem", stem)
self.trunk.add_module("any_stage_block1", Sequential(any_stage_block1_0))
self.head = Sequential(
# TODO (Min): FSDP-mixed_precision doesn't compute the same way as DDP AMP when bias=True,
# so we use bias=False for now in the projection_head.
# The Conv2d layers above do not use bias in regnet, but even if they used
# bias, FSDP and DDP seem to agree on how it is computed.
Sequential(Linear(16, 16, bias=False), ReLU(), Linear(16, 8, bias=False),), # projection_head
Linear(8, 15, bias=False), # prototypes0
)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.trunk(x).reshape(-1)
x = self.head(x)
return x
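
For orientation, a minimal shape check of this test model (assumes the Model class above and the (2, 2, 2, 2) inputs generated in ddp_ref below):

    import torch

    model = Model()
    out = model(torch.rand(2, 2, 2, 2))
    # The trunk ends at (2, 8, 1, 1); reshape(-1) folds the batch dimension into a
    # 16-element vector, so the head emits a single 15-way output for the whole batch.
    assert out.shape == (15,)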
......@@ -67,9 +155,10 @@ def ddp_ref():
state_before = model.state_dict()
# Get reference inputs per rank.
world_size = 2
iterations = 100
inputs = [[]] * world_size
world_size = _world_size
iterations = _iterations
print(f"Getting DDP reference for world_size {world_size} and iterations {iterations}")
inputs = [[] for i in range(world_size)]
for rank in range(world_size):
for i in range(iterations):
inputs[rank].append(torch.rand(2, 2, 2, 2))
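
A minimal sketch of the list-aliasing bug this hunk fixes (pure Python, no torch needed):

    shared = [[]] * 2                      # both slots alias the SAME inner list
    shared[0].append("rank0 batch")
    assert shared[1] == ["rank0 batch"]    # rank 1 silently sees rank 0's data

    fixed = [[] for _ in range(2)]         # independent list per rank
    fixed[0].append("rank0 batch")
    assert fixed[1] == []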
......@@ -86,6 +175,7 @@ def ddp_ref():
args=(
world_size,
fsdp_config,
None,
precision == "mixed",
temp_file_name,
unused,
......@@ -128,6 +218,7 @@ def _test_func(
rank,
world_size,
fsdp_config,
fsdp_wrap_bn,
ddp_mixed_precision,
tempfile_name,
unused,
......@@ -143,27 +234,51 @@ def _test_func(
if fsdp_config:
ddp = False
assert isinstance(fsdp_config, dict), str(fsdp_config)
if fsdp_config["mixed_precision"]:
# To match DDP in AMP -O1, we need fp32 reduce scatter.
fsdp_config["fp32_reduce_scatter"] = True
model = Model()
model.load_state_dict(state_before)
model = model.cuda()
class DummyScaler:
def scale(self, loss):
return loss
def step(self, optim):
optim.step()
def update(self):
pass
scaler = DummyScaler()
if ddp:
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = DDP(model, device_ids=[rank])
model = DDP(model, device_ids=[rank], broadcast_buffers=True)
if ddp_mixed_precision:
scaler = GradScaler()
else:
# Note: different ranks may wrap in a different order due to different random
# seeds, but the results should be the same.
if random.randint(0, 1) == 0:
print("auto_wrap_bn, then convert_sync_batchnorm")
model = auto_wrap_bn(model)
model = SyncBatchNorm.convert_sync_batchnorm(model)
print(f"auto_wrap_bn {fsdp_wrap_bn}, then convert_sync_batchnorm")
if fsdp_wrap_bn:
model = auto_wrap_bn(model, _single_rank_pg)
model = _bn_converter(model)
else:
print("convert_sync_batchnorm, then auto_wrap_bn")
model = SyncBatchNorm.convert_sync_batchnorm(model)
model = auto_wrap_bn(model)
print(f"convert_sync_batchnorm, then auto_wrap_bn {fsdp_wrap_bn}")
model = _bn_converter(model)
if fsdp_wrap_bn:
model = auto_wrap_bn(model, _single_rank_pg)
model = FSDP(model, **fsdp_config).cuda()
if fsdp_config["mixed_precision"]:
scaler = ShardedGradScaler()
# Print the model for verification.
if rank == 0:
print(model)
optim = SGD(model.parameters(), lr=0.1)
loss_func = CrossEntropyLoss()
for in_data in inputs[rank]:
in_data = in_data.cuda()
......@@ -171,11 +286,15 @@ def _test_func(
if ddp and ddp_mixed_precision:
in_data = in_data.half()
context = torch.cuda.amp.autocast(enabled=True)
if not ddp and fsdp_config["mixed_precision"]:
context = torch.cuda.amp.autocast(enabled=True)
with context:
out = model(in_data)
loss = out.sum()
loss.backward()
optim.step()
fake_label = torch.zeros(1, dtype=torch.long).cuda()
loss = loss_func(out.unsqueeze(0), fake_label)
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
optim.zero_grad()
if ddp:
......@@ -190,6 +309,15 @@ def _test_func(
# Move tensors to CPU to compare numerics.
for k, v in fsdp_state.items():
fsdp_state[k] = v.cpu()
# Change False to True to enable this when you want to debug the mismatch.
if False and rank == 0:
def dump(d):
for k, v in d.items():
print(k, v)
dump(state_after)
dump(fsdp_state)
assert objects_are_equal(state_after, fsdp_state, raise_exception=True)
teardown()
......@@ -215,10 +343,32 @@ def test1(temp_files, ddp_ref, precision, flatten):
fsdp_config["mixed_precision"] = precision == "mixed"
fsdp_config["flatten_parameters"] = flatten == "flatten"
world_size = 2
if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
pytest.skip("Only CUDA 11 is supported with AMP equivalency")
# Wrap BN half of the time in full precision mode.
wrap_bn = True
if random.randint(0, 1) == 0:
wrap_bn = False
# Always wrap BN in mixed precision mode.
if fsdp_config["mixed_precision"]:
wrap_bn = True
world_size = _world_size
mp.spawn(
_test_func,
args=(world_size, fsdp_config, None, temp_files[0], temp_files[1], state_before, inputs, None, state_after),
args=(
world_size,
fsdp_config,
wrap_bn,
None,
temp_files[0],
temp_files[1],
state_before,
inputs,
None,
state_after,
),
nprocs=world_size,
join=True,
)