Unverified Commit 175fdeb0 authored by Benjamin Lefaudeux, committed by GitHub

[feature] Unit test with and without buckets for all ShardedDDP unit tests (#400)

* test with and without buckets for all the ShardedDDP unit tests (see the parametrization sketch below)
* parametrize all the things
* refactoring, adding even more combinations at times
* handle hosts not having cuda
parent 4396ef4a
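The test matrix introduced by this commit follows the pattern sketched below. The test name and body are illustrative, but the `reduce_buffer_size` values of 0 (no buckets) and 2 ** 20 (1M-element buckets) are the ones used throughout the new parametrizations.

import pytest

@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])  # without / with ShardedDDP buckets
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
def test_matrix_sketch(reduce_buffer_size, broadcast_buffers, grad_accumulation):
    # The real tests below spawn one process per rank with mp.spawn and wrap the
    # model in ShardedDataParallel(..., reduce_buffer_size=reduce_buffer_size);
    # this stub only shows how the combinations are generated.
    assert reduce_buffer_size in (0, 2 ** 20)
    assert isinstance(broadcast_buffers, bool) and isinstance(grad_accumulation, bool)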
......@@ -51,9 +51,9 @@ class ShardedDataParallel(nn.Module):
Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
reduce_buffer_size (int):
The max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
            The max size of the buffer used to batch the small parameter tensors, in number of elements (default: 0, i.e. unused).
            This will impact long-term memory consumption, because these buckets correspond to parameters which will not be sharded.
Set to 0 to remove all bucketing.
            Set to 0 to remove all bucketing; 1M to 8M is usually reasonable.
auto_refresh_trainable (bool):
            (default: True) Check whether the parameters' trainability (`requires_grad`) has changed and update both ShardedDDP
and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
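As a usage sketch of the knob described above, assuming `torch.distributed` has already been initialized in this process and that `ShardedDataParallel` is imported from `fairscale.nn.data_parallel`:

import torch
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)

# reduce_buffer_size=0 (the new default) disables bucketing entirely;
# a value in the 1M to 8M element range enables bucketed gradient reduction.
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=2 ** 20)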
......@@ -98,7 +98,7 @@ class ShardedDataParallel(nn.Module):
process_group: Any = None,
broadcast_buffers: bool = True,
sync_models_at_startup: bool = True,
reduce_buffer_size: int = 2 ** 23,
reduce_buffer_size: int = 0,
auto_refresh_trainable: bool = True,
):
super().__init__()
......@@ -111,6 +111,7 @@ class ShardedDataParallel(nn.Module):
# Handle a no_sync() context which prevents the gradient synchronization,
# accumulate in place
self.should_accumulate_grads = False
self.accumulate_grads_flipped = False
# Communication related attributes
self.process_group = process_group if process_group is not None else dist.group.WORLD
......@@ -153,10 +154,6 @@ class ShardedDataParallel(nn.Module):
# - setup buckets and tensor views
model_size = sum([p.numel() for p in self.module.parameters()])
if dist.get_world_size(self.process_group) <= 8:
logging.info("Assuming single node environment. De-activating ShardedDDP buckets")
reduce_buffer_size = 0
self.buffer_max_size = min(reduce_buffer_size, model_size)
logging.info(
"ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
......@@ -230,6 +227,11 @@ class ShardedDataParallel(nn.Module):
.. note::
This method modifies the module in-place.
        .. warning::
            Device changes are not supported, and this will raise an exception. The issue in that case is not
            really ShardedDDP, but OSS, which will not be aware of the device change and whose buffers will be
            in a broken state.
Arguments:
device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
......@@ -237,14 +239,18 @@ class ShardedDataParallel(nn.Module):
Returns:
Module: self.
"""
for device in self.buckets.keys():
for bucket in self.buckets[device]:
        assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSS's state"
assert (
len(self.buckets.keys()) == 1
), "Several devices specified to begin with, incompatible with setting a single device here"
for _device in self.buckets.keys():
for bucket in self.buckets[_device]:
bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.module.to(device)
self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
def refresh_trainable(self) -> None:
""" If the module trainability has changed, update all the assumptions """
......@@ -320,7 +326,7 @@ class ShardedDataParallel(nn.Module):
See :meth:`torch.optim.Optimizer.zero_grad` for details.
"""
for index, trainable_param in enumerate(self._trainable_params):
for index, trainable_param in enumerate(self._all_params):
if set_to_none and not self._should_bucket_grad[index]:
trainable_param.grad = None
elif trainable_param.grad is not None:
......@@ -339,6 +345,7 @@ class ShardedDataParallel(nn.Module):
old_should_accumulate_grads = self.should_accumulate_grads
self.should_accumulate_grads = True
yield
self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
self.should_accumulate_grads = old_should_accumulate_grads
@torch.no_grad()
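The `no_sync()` context above is what the test closures further down rely on for gradient accumulation; a minimal usage sketch (the helper name is illustrative, and `ddp_model` / `optimizer` are assumed to have been built as in the previous sketches):

def accumulate_then_step(ddp_model, optimizer, batches, accumulation_steps=2):
    # Gradients are accumulated locally inside no_sync() (no cross-rank reduction),
    # then reduced and applied on the micro-batch that runs outside the context.
    for i, batch in enumerate(batches):
        if (i + 1) % accumulation_steps != 0:
            with ddp_model.no_sync():
                ddp_model(batch).abs().sum().backward()
        else:
            ddp_model(batch).abs().sum().backward()
            optimizer.step()
            optimizer.zero_grad()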
......@@ -352,13 +359,19 @@ class ShardedDataParallel(nn.Module):
assert self._bucket_list is not None
for bucket in self._bucket_list:
assert not self.training or self.should_accumulate_grads or bucket.sent, (
"A bucket failed to be sent, probably unused parameters."
+ "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
assert (
self.accumulate_grads_flipped or not self.training or self.should_accumulate_grads or bucket.sent
), (
"A bucket failed to be sent, probably unused parameters. "
+ "Either mark the unused parameter as not trainable (`.requires_grad = False`) "
+ "or de-activate ShardedDDP buckets -set `reduce_buffer_size` to 0-"
)
bucket.reset()
if not self.should_accumulate_grads:
self.accumulate_grads_flipped = False
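The assertion above names two remedies for models holding parameters that never receive gradients; a sketch of both, assuming `torch.distributed` is already initialized and pretending, for illustration, that the second layer of the model is never used in the forward pass:

import torch
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

model = Sequential(Linear(2, 3), Linear(3, 3))

# Remedy 1: mark the parameters that never get a gradient as non-trainable,
# so that no reduce bucket waits for them.
for p in model[1].parameters():  # illustrative: pretend model[1] is unused
    p.requires_grad = False

optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)

# Remedy 2: disable bucketing altogether when wrapping the model.
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=0)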
def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
""" Look up where this parameter belongs to """
for optim in self.sharded_optimizers:
......@@ -394,10 +407,12 @@ class ShardedDataParallel(nn.Module):
param.grad = None
            # Async reduce for this buffer, keep track of the work handle
dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
tensor=param.grad.data, dst=dst_global_rank, group=self.process_group, async_op=True
),
callback=cleanup,
)
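The change above matters because `torch.distributed` collectives expect `dst`/`src` to be global ranks (ranks in the world group) even when a sub-group is passed; a sketch of the translation, where `reduce_to_owner` is an illustrative helper rather than part of ShardedDDP:

import torch.distributed as dist

from fairscale.optim import OSS

def reduce_to_owner(grad, dst_rank, process_group):
    # dst_rank is the rank *within* process_group that owns this parameter shard;
    # it has to be mapped to the corresponding global rank before calling reduce().
    dst_global_rank = OSS.get_global_rank(process_group, dst_rank)
    return dist.reduce(tensor=grad, dst=dst_global_rank, group=process_group, async_op=True)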
......@@ -435,7 +450,10 @@ class ShardedDataParallel(nn.Module):
self._work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
tensor=bucket.buffer,
dst=bucket.destination,
group=self.process_group,
async_op=True,
),
callback=None,
)
......@@ -470,33 +488,11 @@ class ShardedDataParallel(nn.Module):
p_tmp = param.expand_as(param)
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
dst_rank = self._trainable_param_to_rank[param]
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
self._grad_accs.append(grad_acc) # keep this function in scope
# Add a hook on the module to flush the buckets, if needed
if self.use_buckets:
def bucket_flush(*_: Any) -> None:
assert self._bucket_list is not None
handle = None
for bucket in self._bucket_list:
if not bucket.sent:
# Reduce the bucket. Some parameters went unused and this bucket was not flushed
bucket.buffer.mul_(self.world_size_scaling)
bucket.sent = True
handle = dist.reduce(
tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
)
# Only wait on the last handle
if handle:
handle.wait()
self.module.register_backward_hook(bucket_flush)
@torch.no_grad()
def _sync_params_and_buffers(self) -> None:
"""
......@@ -545,7 +541,7 @@ class ShardedDataParallel(nn.Module):
for param in self._trainable_params:
device = param.device
dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
dst_rank = self._trainable_param_to_rank[param]
if param.device not in self.buckets.keys():
self.buckets[param.device] = [
......@@ -554,7 +550,7 @@ class ShardedDataParallel(nn.Module):
]
bucket = self.buckets[device][dst_rank]
bucket.destination = dst_rank
bucket.destination = OSS.get_global_rank(self.process_group, dst_rank)
# Criteria to decide whether this parameter is to be bucketed or not:
# - enough room in the bucket
......
......@@ -412,7 +412,7 @@ class OSS(Optimizer):
def refresh_trainable(self) -> None:
""" Updates the partitioning and communication patterns if the trainability (`requires_grad`)
of some parameters changed
of some parameters changed.
"""
# Create the optim which will work on the param shard
......
......@@ -54,7 +54,7 @@ skip_if_single_gpu = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
)
skip_if_less_four_gpu = pytest.mark.skipif(
skip_if_less_than_four_gpu = pytest.mark.skipif(
not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason="4 GPUs or more required"
)
......@@ -67,6 +67,11 @@ skip_if_py39_no_cuda = pytest.mark.skipif(
reason="Python3.9 wo CUDA is skipped",
)
available_devices = ["cpu"]
if torch.cuda.is_available():
available_devices.append("cuda")
_, filename_mpi = tempfile.mkstemp()
......@@ -418,3 +423,31 @@ def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module,
for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
def check_same_models_across_ranks(
model: torch.nn.Module, process_group: Any, params_should_be_equal: bool, check_broadcast_buffers: bool
) -> None:
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
for param in model.parameters():
        # collect the params across the ranks
receptacle = [param.clone() for _ in range(world_size)]
dist.all_gather(receptacle, param, group=process_group)
if rank == 0:
for sync_p in receptacle[1:]:
assert not params_should_be_equal or torch.all(
torch.eq(receptacle[0], sync_p)
), "Models differ in between ranks"
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
if check_broadcast_buffers:
for buffer in model.buffers():
receptacle = [buffer.clone() for _ in range(world_size)]
dist.all_gather(receptacle, buffer, group=process_group)
if rank == 0:
for sync_b in receptacle[1:]:
assert not params_should_be_equal or torch.all(
torch.eq(receptacle[0], sync_b)
), "Models differ in between ranks"
......@@ -10,9 +10,9 @@ Testing ShardedDDP
from contextlib import suppress
import copy
import tempfile
from typing import List
import numpy as np
import pytest
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
......@@ -25,126 +25,126 @@ from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import (
GPT2,
available_devices,
check_same_model_params,
skip_if_less_four_gpu,
check_same_models_across_ranks,
skip_if_less_than_four_gpu,
skip_if_no_cuda,
skip_if_py38,
skip_if_single_gpu,
)
def run_one_step(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def _get_mlp():
return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = _get_mlp()
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
def run_one_step(
rank, world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size,
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
# Any model works. Add one different buffer per rank
model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
next(model.parameters()).requires_grad = False # Test non-trainable parameters
next(model.parameters()).requires_grad = False # Test non-trainable parameters
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(
model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
# The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
check_same_models_across_ranks(
ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
)
def check_same_model_params(same_params: bool):
# Check that all the params are the same on all ranks
# This should be true with and without broadcast_buffers, we don't have any real buffer here
receptacle: List[torch.Tensor] = []
if dist.get_backend() != "nccl":
for pg in optimizer.param_groups:
for p in pg["params"]:
# Check the params
receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
dist.gather(p, receptacle, dst=0)
if rank == 0:
for sync_p in receptacle[1:]:
if same_params:
assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
else:
assert not torch.all(
torch.eq(receptacle[0], sync_p)
), "Gradients should not have been synced"
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
if broadcast_buffers:
for b in ddp_model.buffers():
receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
dist.gather(b, receptacle, dst=0)
if rank == 0:
for sync_b in receptacle[1:]:
if same_params:
assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
else:
assert not torch.all(
torch.eq(receptacle[0], sync_b)
), "Gradients should not have been synced"
assert b.cpu().item() == 0.0
# The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
check_same_model_params(same_params=True)
# Optim loop
def closure():
optimizer.zero_grad()
with ddp_model.no_sync() if grad_accumulation else suppress():
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# Optim loop
def closure():
optimizer.zero_grad()
with ddp_model.no_sync() if grad_accumulation else suppress():
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
# when running on cpu/gloo the "nodes" are not really different
same_params = device == torch.device("cpu") or grad_accumulation
check_same_models_across_ranks(
ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
)
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
# when running on cpu/gloo the "nodes" are not really different
same_params = device == torch.device("cpu") or grad_accumulation
check_same_model_params(same_params=same_params)
check(broadcast_buffers=False)
check(broadcast_buffers=True)
check(broadcast_buffers=False, grad_accumulation=True)
check(broadcast_buffers=True, grad_accumulation=True)
dist.destroy_process_group()
def run_test(backend, device, world_size=2):
def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
mp.spawn(
run_one_step,
args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
nprocs=world_size,
join=True,
)
@skip_if_no_cuda
@skip_if_single_gpu
def test_step_gpu():
run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_gpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
world_size = 2
run_test(
dist.Backend.NCCL, torch.device("cuda"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
)
@skip_if_py38
def test_step_cpu():
run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"))
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_step_cpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
world_size = 2
run_test(
dist.Backend.GLOO, torch.device("cpu"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
)
def run_ddp_parity(rank, world_size, backend, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def run_ddp_parity(
rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph
):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
device = torch.device("cuda")
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
NUMBER_BATCHS = 5
INPUTS = 2
BATCH_SIZE = 32
BATCH_SIZE = 8
def check_parity(amp: bool, accumulate: bool, change_train_graph: bool, manual_reduction: bool):
def check_parity(amp: bool, manual_reduction: bool):
# The API should be the exact same in between the sharded and non-sharded variants, generic closure
def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
......@@ -174,7 +174,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
model.reduce()
# Any model works. Add one different buffer per rank
model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
......@@ -182,13 +182,16 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
# properly reassigned when/if this changes
next(model.parameters()).requires_grad = False
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
sharded_ddp_model = ShardedDataParallel(
module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
module=model,
sharded_optimizer=sharded_optimizer,
broadcast_buffers=True,
reduce_buffer_size=reduce_buffer_size,
)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)
ddp_scaler = TorchGradScaler() if amp else None
......@@ -199,14 +202,18 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
# Typical training loop, check that we get the exact same results as DDP
for i in range(NUMBER_BATCHS):
input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)
input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)
def closure_ddp(input_tensor=input_tensor):
return closure(ddp_model, ddp_scaler, input_tensor, accumulate)
return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)
def closure_sharded(input_tensor=input_tensor):
return closure(
sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate, _manual_reduction=manual_reduction
sharded_ddp_model,
sharded_ddp_scaler,
input_tensor,
grad_accumulation,
_manual_reduction=manual_reduction,
)
# Step/scale both
......@@ -234,77 +241,82 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
# Test all combinations: AMP, Accumulate, Change train graph
# Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
amp_tests = [False]
if hasattr(torch.cuda.amp, "autocast"):
amp_tests.append(True)
for accumulate in [False, True]:
for change_train_graph in [False, True]:
manual_reductions = [False, True] if not accumulate and not change_train_graph else [False]
for manual_reduction in manual_reductions:
for amp in amp_tests:
print(
f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp} - manual reduction {manual_reduction}"
)
check_parity(
amp=amp,
accumulate=accumulate,
change_train_graph=change_train_graph,
manual_reduction=manual_reduction,
)
manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
for manual_reduction in manual_reductions:
for amp in amp_tests:
print(
f"Checking configuration: accumulate {grad_accumulation}"
+ f" - change train graph {change_train_graph}"
+ f" - amp {amp}"
+ f" - manual reduction {manual_reduction}"
+ f" - buffers {reduce_buffer_size}",
flush=True,
)
check_parity(
amp=amp, manual_reduction=manual_reduction,
)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_parity():
temp_file_name = tempfile.mkstemp()[1]
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("change_train_graph", [True, False])
def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph):
world_size = torch.cuda.device_count()
backend = dist.Backend.NCCL
mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
mp.spawn(
run_ddp_parity,
args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size, grad_accumulation, change_train_graph),
nprocs=world_size,
join=True,
)
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
device = torch.device("cuda")
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank) # Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
BATCHS = 20
model = _get_mlp()
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
n_half_params = len(list(model.parameters())) // 2
optim_settings = {"lr": 1e-3, "momentum": 0.99}
sharded_optimizer = OSS(
params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
)
sharded_optimizer_2 = OSS(
params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
)
sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, **optim_settings)
sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, **optim_settings)
sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
sharded_ddp_model = ShardedDataParallel(
module=model,
sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
broadcast_buffers=True,
reduce_buffer_size=reduce_buffer_size,
)
ddp_model_single = copy.deepcopy(model)
ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
def check_same_model_params():
for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
assert torch.allclose(
p, ddp_p, atol=1e-3
), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
check_same_model_params() # The models should stay the same in between the ranks
check_same_model_params(
sharded_ddp_model,
ddp_model,
f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
)
for i in range(20):
for i in range(BATCHS):
input_tensor = torch.rand((64, 2)).to(device)
# Run DDP
......@@ -314,6 +326,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
ddp_loss.backward()
ddp_optimizer.step()
ddp_optimizer_2.step()
torch.cuda.synchronize(device)
# Run Sharded
sharded_optimizer.zero_grad()
......@@ -322,43 +335,40 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
sharded_loss.backward()
sharded_optimizer.step()
sharded_optimizer_2.step()
check_same_model_params()
torch.cuda.synchronize(device)
check_same_model_params(
sharded_ddp_model, ddp_model, f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
)
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_parity_two_optim():
temp_file_name = tempfile.mkstemp()[1]
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_ddp_parity_two_optim(reduce_buffer_size):
world_size = 2
backend = dist.Backend.NCCL
mp.spawn(run_ddp_parity_two_optim, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
mp.spawn(
run_ddp_parity_two_optim,
args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size),
nprocs=world_size,
join=True,
)
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == "cuda":
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
model = _DoubleInput().to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)
# Optim loop
def closure():
......@@ -374,25 +384,32 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
dist.destroy_process_group()
def test_inputs():
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
@pytest.mark.parametrize("device", available_devices)
def test_inputs(reduce_buffer_size, backend, device):
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
if backend == "nccl" and device == "cpu":
pytest.skip("Incompatible combination, or cuda not available")
return
mp.spawn(
run_test_two_inputs,
args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
nprocs=world_size,
join=True,
)
def test_ddp_attributes():
# Check that ShardedDDP exposes the same attributes as Pytorch's DDP
# - is multi_device_module
# - device_type
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert hasattr(ddp_model, "is_multi_device_module")
......@@ -402,14 +419,12 @@ def test_ddp_attributes():
def test_random_attributes():
# Check that ShardedDDP exposes the original module's attributes
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
model = Sequential(Linear(2, 3), Linear(3, 3))
model.banana = "sweet"
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert hasattr(ddp_model, "banana")
......@@ -418,41 +433,51 @@ def test_random_attributes():
dist.destroy_process_group()
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    # Check that changing the device of an already-wrapped module is caught and refused
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
ddp_model.to(device)
inputs = torch.rand((10, 2), device=device)
outputs = ddp_model(inputs) # assert if the module has not been changed properly
loss = outputs.norm().backward()
    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # kept on cpu on purpose, to test changing the device afterwards
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(
model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
)
try:
ddp_model.to(device)
assert False, "Changing devices should be caught and not supported"
except AssertionError:
pass
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_device_change():
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_device_change(reduce_buffer_size):
    # Check that a device change on an already-wrapped model is caught and rejected by ShardedDDP
world_size = 2
backend = "gloo"
backend = "nccl"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
mp.spawn(
run_test_device_change,
args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
nprocs=world_size,
join=True,
)
def run_test_training_change(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
group = dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
group = dist.init_process_group(
init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
)
torch.cuda.set_device(rank)
model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, process_group=group)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)
inputs = torch.rand((10, 2), device=device)
outputs = ddp_model(inputs) # assert if the module has not been changed properly
......@@ -465,23 +490,30 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name):
dist.destroy_process_group()
def test_training_change():
world_size = 8
backend = "gloo"
@skip_if_no_cuda
@skip_if_single_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_training_change(reduce_buffer_size):
world_size = 2
backend = "nccl"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_training_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
device = "cuda"
mp.spawn(
run_test_training_change,
args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
nprocs=world_size,
join=True,
)
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.to(device) # in pytorch 1.5 syncBN switches to the default device/cpu
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert isinstance(model[1], torch.nn.SyncBatchNorm)
......@@ -504,29 +536,17 @@ def test_ddp_sync_batch_norm():
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
model = _DoubleInput().to(device)
parameters = list(model.parameters())
optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])
# Optim loop
......@@ -556,22 +576,21 @@ def test_two_optimizers():
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
INPUT_DIM = 32
INPUT_DIM = 16
BACH_SIZE = 10
STEPS = 10
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
model = GPT2(
embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=0)
# Optim loop
def closure():
......@@ -600,7 +619,7 @@ def test_gpt2():
mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
# Only work with the even ranks, to check that the global_rank indexing is properly used
dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)
......@@ -635,17 +654,9 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
_ = optimizer.step(closure=closure)
# Check that all the params are the same on all ranks
for pg in optimizer.param_groups:
for p in pg["params"]:
receptacle = [p.clone() for _ in sub_group_ranks]
dist.all_gather(receptacle, p, group=process_group)
if rank == 0:
for sync_p in receptacle[1:]:
assert torch.all(
torch.eq(receptacle[0], sync_p)
), "Models differ in between ranks {} - {}".format(
torch.norm(receptacle[0]), torch.norm(sync_p)
)
check_same_models_across_ranks(
model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
)
if rank in sub_group_ranks:
        # Model not fitting in the broadcast bucket
......@@ -654,24 +665,25 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
)
# With SGD, Momentum is required to get a state to shard
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group)
model = ShardedDataParallel(model, optimizer, process_group=process_group)
optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
model = ShardedDataParallel(
model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
)
check(optimizer, model)
dist.destroy_process_group(process_group)
@skip_if_less_four_gpu
def test_multiple_groups():
@skip_if_less_than_four_gpu
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("backend", ["gloo", "nccl"])
def test_multiple_groups(reduce_buffer_size, backend):
world_size = 4
temp_file_name = tempfile.mkstemp()[1]
for backend in ["gloo", "nccl"]:
print("Testing backend ", backend)
mp.spawn(
run_test_multiple_groups, args=(world_size, temp_file_name, backend), nprocs=world_size, join=True,
)
mp.spawn(
run_test_multiple_groups, args=(world_size, temp_file_name, "gloo"), nprocs=world_size, join=True,
run_test_multiple_groups,
args=(world_size, temp_file_name, backend, reduce_buffer_size),
nprocs=world_size,
join=True,
)