"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "65e5bb3ea1d28b248c44fa83ce6eccf719750c37"
Unverified commit 175fdeb0, authored by Benjamin Lefaudeux and committed by GitHub

[feature] Unit test with and without buckets for all ShardedDDP unit tests (#400)

* test with and without buckets for all the ShardedDDP unit tests
* parametrize all the things (see the sketch below)
* refactoring, adding even more combinations in places
* handle hosts without CUDA
parent 4396ef4a
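For readers skimming the diff below, the core of the change is that every ShardedDDP test now runs both without buckets (`reduce_buffer_size=0`, the new default) and with a 2**20-element reduce buffer, through stacked `pytest.mark.parametrize` decorators. A minimal, self-contained sketch of that pattern; the `run_test` body here is only a stand-in for the real `mp.spawn`-based helpers defined in the test diff:

```python
import pytest


def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
    # Stand-in for the mp.spawn-based helper in the diff below; it only checks the parametrized values.
    assert reduce_buffer_size in (0, 2 ** 20)
    assert isinstance(broadcast_buffers, bool) and isinstance(grad_accumulation, bool)


@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])  # 0 disables bucketing
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size):
    # Eight configurations are generated from the three decorators above.
    run_test("gloo", "cpu", 2, broadcast_buffers, grad_accumulation, reduce_buffer_size)
```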
@@ -51,9 +51,9 @@ class ShardedDataParallel(nn.Module):
         Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
         or the training restarts from a saved state
         reduce_buffer_size (int):
-            The max size of the buffer used to batch the small parameter tensors, in number of elements (default 8M).
+            The max size of the buffer used to batch the small parameter tensors, in number of elements (default 0 - unused).
             this will impact the long term memory consumption, because these buckets correspond to parameters which will not be sharded.
-            Set to 0 to remove all bucketing.
+            Set to 0 to remove all bucketing, 1M to 8M is usually reasonable.
         auto_refresh_trainable (bool):
             (default: True) Check whether the parameters trainability (`requires_grad`) has changed and update both ShardedDDP
             and OSS automatically if this is the case. If set to False, `refresh_trainable()` needs to be called anytime
@@ -98,7 +98,7 @@ class ShardedDataParallel(nn.Module):
         process_group: Any = None,
         broadcast_buffers: bool = True,
         sync_models_at_startup: bool = True,
-        reduce_buffer_size: int = 2 ** 23,
+        reduce_buffer_size: int = 0,
         auto_refresh_trainable: bool = True,
     ):
         super().__init__()
@@ -111,6 +111,7 @@ class ShardedDataParallel(nn.Module):
         # Handle a no_sync() context which prevents the gradient synchronization,
         # accumulate in place
         self.should_accumulate_grads = False
+        self.accumulate_grads_flipped = False

         # Communication related attributes
         self.process_group = process_group if process_group is not None else dist.group.WORLD
@@ -153,10 +154,6 @@ class ShardedDataParallel(nn.Module):
         # - setup buckets and tensor views
         model_size = sum([p.numel() for p in self.module.parameters()])
-        if dist.get_world_size(self.process_group) <= 8:
-            logging.info("Assuming single node environment. De-activating ShardedDDP buckets")
-            reduce_buffer_size = 0
-
         self.buffer_max_size = min(reduce_buffer_size, model_size)
         logging.info(
             "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
@@ -230,6 +227,11 @@ class ShardedDataParallel(nn.Module):
         .. note::
             This method modifies the module in-place.

+        .. warning:
+            Device changes are not supported, and this will raise an exception. The issue in that case is not
+            really ShardedDDP, but OSS which will not be aware of the device change, and whose buffers will be
+            in a broken state.
+
         Arguments:
             device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
             dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
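The warning documents an ordering constraint: move the module to its final device before wrapping it, because `to()` only moves the bucket buffers and the inner module, and the OSS state would not follow a device change (the next hunk turns this into an assertion, and `run_test_device_change` below exercises it). Continuing the sketch above, with the safe ordering and the now-forbidden call left commented out:

```python
# Safe ordering (sketch): choose the device first, then build OSS and the ShardedDDP wrapper.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 8).to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# ddp_model.to("cpu")  # after this commit, changing devices trips the assertion added below
```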
@@ -237,14 +239,18 @@ class ShardedDataParallel(nn.Module):
         Returns:
             Module: self.
         """
-        for device in self.buckets.keys():
-            for bucket in self.buckets[device]:
+        assert device in self.buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
+        assert (
+            len(self.buckets.keys()) == 1
+        ), "Several devices specified to begin with, incompatible with setting a single device here"
+
+        for _device in self.buckets.keys():
+            for bucket in self.buckets[_device]:
                 bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)

-        self.module.to(device)
+        self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)

     def refresh_trainable(self) -> None:
         """ If the module trainability has changed, update all the assumptions """
@@ -320,7 +326,7 @@ class ShardedDataParallel(nn.Module):
         See :meth:`torch.optim.Optimizer.zero_grad` for details.
         """
-        for index, trainable_param in enumerate(self._trainable_params):
+        for index, trainable_param in enumerate(self._all_params):
             if set_to_none and not self._should_bucket_grad[index]:
                 trainable_param.grad = None
             elif trainable_param.grad is not None:
@@ -339,6 +345,7 @@ class ShardedDataParallel(nn.Module):
         old_should_accumulate_grads = self.should_accumulate_grads
         self.should_accumulate_grads = True
         yield
+        self.accumulate_grads_flipped = self.should_accumulate_grads != old_should_accumulate_grads
         self.should_accumulate_grads = old_should_accumulate_grads

     @torch.no_grad()
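The `accumulate_grads_flipped` flag set when leaving `no_sync()` lets the end-of-backward check (next hunk) distinguish a bucket that was not sent because gradients were being accumulated from one that was not sent because a parameter went unused. For context, gradient accumulation with ShardedDDP follows the same pattern the parametrized tests below use; a runnable single-rank sketch with illustrative sizes (not a drop-in training recipe):

```python
import tempfile
from contextlib import suppress

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(2, 2)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# Accumulate two backward passes locally under no_sync(); the third backward triggers the reduce.
for step in range(3):
    context = ddp_model.no_sync() if step < 2 else suppress()
    with context:
        loss = ddp_model(torch.rand(8, 2)).abs().sum()
        loss.backward()
optimizer.step()
optimizer.zero_grad()

dist.destroy_process_group()
```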
@@ -352,13 +359,19 @@ class ShardedDataParallel(nn.Module):
         assert self._bucket_list is not None

         for bucket in self._bucket_list:
-            assert not self.training or self.should_accumulate_grads or bucket.sent, (
-                "A bucket failed to be sent, probably unused parameters."
-                + "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
+            assert (
+                self.accumulate_grads_flipped or not self.training or self.should_accumulate_grads or bucket.sent
+            ), (
+                "A bucket failed to be sent, probably unused parameters. "
+                + "Either mark the unused parameter as not trainable (`.requires_grad = False`) "
+                + "or de-activate ShardedDDP buckets -set `reduce_buffer_size` to 0-"
             )
             bucket.reset()

+        if not self.should_accumulate_grads:
+            self.accumulate_grads_flipped = False
+
     def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
         """ Look up where this parameter belongs to """
         for optim in self.sharded_optimizers:
@@ -394,10 +407,12 @@ class ShardedDataParallel(nn.Module):
                     param.grad = None

                 # Async reduce for this buffer, log the future
+                dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)
                 self._work_handles.append(
                     Workhandle(
                         handle=dist.reduce(
-                            tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
+                            tensor=param.grad.data, dst=dst_global_rank, group=self.process_group, async_op=True
                         ),
                         callback=cleanup,
                     )
@@ -435,7 +450,10 @@ class ShardedDataParallel(nn.Module):
                         self._work_handles.append(
                             Workhandle(
                                 handle=dist.reduce(
-                                    tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
+                                    tensor=bucket.buffer,
+                                    dst=bucket.destination,
+                                    group=self.process_group,
+                                    async_op=True,
                                 ),
                                 callback=None,
                             )
@@ -470,33 +488,11 @@ class ShardedDataParallel(nn.Module):
                 p_tmp = param.expand_as(param)
                 assert p_tmp.grad_fn is not None
                 grad_acc = p_tmp.grad_fn.next_functions[0][0]
-                dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
+                dst_rank = self._trainable_param_to_rank[param]

                 grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank))
                 self._grad_accs.append(grad_acc)  # keep this function in scope

-        # Add a hook on the module to flush the buckets, if needed
-        if self.use_buckets:
-
-            def bucket_flush(*_: Any) -> None:
-                assert self._bucket_list is not None
-
-                handle = None
-                for bucket in self._bucket_list:
-                    if not bucket.sent:
-                        # Reduce the bucket. Some parameters went unused and this bucket was not flushed
-                        bucket.buffer.mul_(self.world_size_scaling)
-                        bucket.sent = True
-                        handle = dist.reduce(
-                            tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
-                        )
-
-                # Only wait on the last handle
-                if handle:
-                    handle.wait()
-
-            self.module.register_backward_hook(bucket_flush)
-
     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
         """
@@ -545,7 +541,7 @@ class ShardedDataParallel(nn.Module):
         for param in self._trainable_params:
             device = param.device
-            dst_rank = OSS.get_global_rank(self.process_group, self._trainable_param_to_rank[param])
+            dst_rank = self._trainable_param_to_rank[param]

             if param.device not in self.buckets.keys():
                 self.buckets[param.device] = [
@@ -554,7 +550,7 @@ class ShardedDataParallel(nn.Module):
                 ]
             bucket = self.buckets[device][dst_rank]
-            bucket.destination = dst_rank
+            bucket.destination = OSS.get_global_rank(self.process_group, dst_rank)

             # Criteria to decide whether this parameter is to be bucketed or not:
             # - enough room in the bucket
...
@@ -412,7 +412,7 @@ class OSS(Optimizer):
     def refresh_trainable(self) -> None:
         """ Updates the partitioning and communication patterns if the trainability (`requires_grad`)
-        of some parameters changed
+        of some parameters changed.
         """

         # Create the optim which will work on the param shard
...
@@ -54,7 +54,7 @@ skip_if_single_gpu = pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
 )

-skip_if_less_four_gpu = pytest.mark.skipif(
+skip_if_less_than_four_gpu = pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason="4 GPUs or more required"
 )
@@ -67,6 +67,11 @@ skip_if_py39_no_cuda = pytest.mark.skipif(
     reason="Python3.9 wo CUDA is skipped",
 )

+available_devices = ["cpu"]
+if torch.cuda.is_available():
+    available_devices.append("cuda")
+
 _, filename_mpi = tempfile.mkstemp()
@@ -418,3 +423,31 @@ def check_same_model_params(model_a: torch.nn.Module, model_b: torch.nn.Module,
     for b_a, b_b in zip(model_a.buffers(), model_b.buffers()):
         assert torch.allclose(b_a, b_b), f"Model buffers differ {b_a} - {b_b}\n" + message
+
+
+def check_same_models_across_ranks(
+    model: torch.nn.Module, process_group: Any, params_should_be_equal: bool, check_broadcast_buffers: bool
+) -> None:
+    world_size = dist.get_world_size(process_group)
+    rank = dist.get_rank(process_group)
+
+    for param in model.parameters():
+        # collect the params across the rank
+        receptacle = [param.clone() for _ in range(world_size)]
+        dist.all_gather(receptacle, param, group=process_group)
+
+        if rank == 0:
+            for sync_p in receptacle[1:]:
+                assert not params_should_be_equal or torch.all(
+                    torch.eq(receptacle[0], sync_p)
+                ), "Models differ in between ranks"
+
+    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
+    if check_broadcast_buffers:
+        for buffer in model.buffers():
+            receptacle = [buffer.clone() for _ in range(world_size)]
+            dist.all_gather(receptacle, buffer, group=process_group)
+
+            if rank == 0:
+                for sync_b in receptacle[1:]:
+                    assert not params_should_be_equal or torch.all(
+                        torch.eq(receptacle[0], sync_b)
+                    ), "Models differ in between ranks"
@@ -10,9 +10,9 @@ Testing ShardedDDP
 from contextlib import suppress
 import copy
 import tempfile
-from typing import List

 import numpy as np
+import pytest
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
 import torch.distributed as dist
@@ -25,126 +25,126 @@ from fairscale.optim import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler
 from fairscale.utils.testing import (
     GPT2,
+    available_devices,
     check_same_model_params,
-    skip_if_less_four_gpu,
+    check_same_models_across_ranks,
+    skip_if_less_than_four_gpu,
     skip_if_no_cuda,
     skip_if_py38,
     skip_if_single_gpu,
 )


-def run_one_step(rank, world_size, backend, device, temp_file_name):
-    url = "file://" + temp_file_name
-    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+def _get_mlp():
+    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+
+
+class _DoubleInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.mlp = _get_mlp()
+
+    def forward(self, x, y):
+        x1 = self.mlp(x)
+        x2 = self.mlp(y)
+        return torch.cat((x1, x2), dim=1)
+
+
+def run_one_step(
+    rank, world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size,
+):
+    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

     torch.manual_seed(rank)
     np.random.seed(rank)

-    def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
-        # Any model works. Add one different buffer per rank
-        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
-        model.register_buffer("test_buffer", torch.ones((1)) * rank)
-        model.to(device)
-
-        next(model.parameters()).requires_grad = False  # Test non-trainable parameters
-
-        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
-        ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
-
-        def check_same_model_params(same_params: bool):
-            # Check that all the params are the same on all ranks
-            # This should be true with and without broadcast_buffers, we don't have any real buffer here
-            receptacle: List[torch.Tensor] = []
-
-            if dist.get_backend() != "nccl":
-                for pg in optimizer.param_groups:
-                    for p in pg["params"]:
-                        # Check the params
-                        receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
-                        dist.gather(p, receptacle, dst=0)
-                        if rank == 0:
-                            for sync_p in receptacle[1:]:
-                                if same_params:
-                                    assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
-                                else:
-                                    assert not torch.all(
-                                        torch.eq(receptacle[0], sync_p)
-                                    ), "Gradients should not have been synced"
-
-                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
-                if broadcast_buffers:
-                    for b in ddp_model.buffers():
-                        receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
-                        dist.gather(b, receptacle, dst=0)
-                        if rank == 0:
-                            for sync_b in receptacle[1:]:
-                                if same_params:
-                                    assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
-                                else:
-                                    assert not torch.all(
-                                        torch.eq(receptacle[0], sync_b)
-                                    ), "Gradients should not have been synced"
-
-                        assert b.cpu().item() == 0.0
-
-        # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
-        check_same_model_params(same_params=True)
-
-        # Optim loop
-        def closure():
-            optimizer.zero_grad()
-
-            with ddp_model.no_sync() if grad_accumulation else suppress():
-                input_tensor = torch.rand((64, 2)).to(device)
-                loss = ddp_model(input_tensor).abs().sum()
-                loss.backward()
-            return loss
-
-        # The models should stay the same in between the ranks
-        for i in range(5):
-            _ = optimizer.step(closure=closure)
-            # when running on cpu/gloo the "nodes" are not really different
-            same_params = device == torch.device("cpu") or grad_accumulation
-            check_same_model_params(same_params=same_params)
-
-    check(broadcast_buffers=False)
-    check(broadcast_buffers=True)
-    check(broadcast_buffers=False, grad_accumulation=True)
-    check(broadcast_buffers=True, grad_accumulation=True)
+    # Any model works. Add one different buffer per rank
+    model = _get_mlp()
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
+
+    next(model.parameters()).requires_grad = False  # Test non-trainable parameters
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(
+        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
+    )
+
+    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
+    check_same_models_across_ranks(
+        ddp_model, dist.group.WORLD, params_should_be_equal=True, check_broadcast_buffers=broadcast_buffers
+    )
+
+    # Optim loop
+    def closure():
+        optimizer.zero_grad()
+
+        with ddp_model.no_sync() if grad_accumulation else suppress():
+            input_tensor = torch.rand((64, 2)).to(device)
+            loss = ddp_model(input_tensor).abs().sum()
+            loss.backward()
+        return loss
+
+    # The models should stay the same in between the ranks
+    for i in range(5):
+        _ = optimizer.step(closure=closure)
+        # when running on cpu/gloo the "nodes" are not really different
+        same_params = device == torch.device("cpu") or grad_accumulation
+        check_same_models_across_ranks(
+            ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
+        )

     dist.destroy_process_group()
-def run_test(backend, device, world_size=2):
+def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size):
     temp_file_name = tempfile.mkstemp()[1]
-    mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    mp.spawn(
+        run_one_step,
+        args=(world_size, backend, device, temp_file_name, broadcast_buffers, grad_accumulation, reduce_buffer_size),
+        nprocs=world_size,
+        join=True,
+    )


 @skip_if_no_cuda
 @skip_if_single_gpu
-def test_step_gpu():
-    run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
+@pytest.mark.parametrize("broadcast_buffers", [True, False])
+@pytest.mark.parametrize("grad_accumulation", [True, False])
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+def test_step_gpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
+    world_size = 2
+    run_test(
+        dist.Backend.NCCL, torch.device("cuda"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
+    )


 @skip_if_py38
-def test_step_cpu():
-    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"))
+@pytest.mark.parametrize("broadcast_buffers", [True, False])
+@pytest.mark.parametrize("grad_accumulation", [True, False])
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+def test_step_cpu(broadcast_buffers, grad_accumulation, reduce_buffer_size):
+    world_size = 2
+    run_test(
+        dist.Backend.GLOO, torch.device("cpu"), world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size
+    )


-def run_ddp_parity(rank, world_size, backend, temp_file_name):
-    url = "file://" + temp_file_name
-    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+def run_ddp_parity(
+    rank, world_size, backend, temp_file_name, reduce_buffer_size, grad_accumulation, change_train_graph
+):
+    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     device = torch.device("cuda")
     torch.cuda.set_device(rank)
     torch.manual_seed(rank)
     np.random.seed(rank)
     NUMBER_BATCHS = 5
-    INPUTS = 2
-    BATCH_SIZE = 32
+    BATCH_SIZE = 8

-    def check_parity(amp: bool, accumulate: bool, change_train_graph: bool, manual_reduction: bool):
+    def check_parity(amp: bool, manual_reduction: bool):
         # The API should be the exact same in between the sharded and non-sharded variants, generic closure
         def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
@@ -174,7 +174,7 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
                     model.reduce()

         # Any model works. Add one different buffer per rank
-        model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+        model = _get_mlp()
         model.register_buffer("test_buffer", torch.ones((1)) * rank)
         model.to(device)
@@ -182,13 +182,16 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
         # properly reassigned when/if this changes
         next(model.parameters()).requires_grad = False

-        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
+        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
         sharded_ddp_model = ShardedDataParallel(
-            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
+            module=model,
+            sharded_optimizer=sharded_optimizer,
+            broadcast_buffers=True,
+            reduce_buffer_size=reduce_buffer_size,
         )

         ddp_model_single = copy.deepcopy(model)
-        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
+        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
         ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

         ddp_scaler = TorchGradScaler() if amp else None
@@ -199,14 +202,18 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
         # Typical training loop, check that we get the exact same results as DDP
         for i in range(NUMBER_BATCHS):
-            input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)
+            input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

             def closure_ddp(input_tensor=input_tensor):
-                return closure(ddp_model, ddp_scaler, input_tensor, accumulate)
+                return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

             def closure_sharded(input_tensor=input_tensor):
                 return closure(
-                    sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate, _manual_reduction=manual_reduction
+                    sharded_ddp_model,
+                    sharded_ddp_scaler,
+                    input_tensor,
+                    grad_accumulation,
+                    _manual_reduction=manual_reduction,
                 )

             # Step/scale both
@@ -234,77 +241,82 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
             next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
             check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

-    # Test all combinations: AMP, Accumulate, Change train graph
+    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
     amp_tests = [False]
     if hasattr(torch.cuda.amp, "autocast"):
         amp_tests.append(True)

-    for accumulate in [False, True]:
-        for change_train_graph in [False, True]:
-            manual_reductions = [False, True] if not accumulate and not change_train_graph else [False]
-            for manual_reduction in manual_reductions:
-                for amp in amp_tests:
-                    print(
-                        f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp} - manual reduction {manual_reduction}"
-                    )
-                    check_parity(
-                        amp=amp,
-                        accumulate=accumulate,
-                        change_train_graph=change_train_graph,
-                        manual_reduction=manual_reduction,
-                    )
+    manual_reductions = [False, True] if not grad_accumulation and not change_train_graph else [False]
+    for manual_reduction in manual_reductions:
+        for amp in amp_tests:
+            print(
+                f"Checking configuration: accumulate {grad_accumulation}"
+                + f" - change train graph {change_train_graph}"
+                + f" - amp {amp}"
+                + f" - manual reduction {manual_reduction}"
+                + f" - buffers {reduce_buffer_size}",
+                flush=True,
+            )
+            check_parity(
+                amp=amp, manual_reduction=manual_reduction,
+            )

     dist.destroy_process_group()
 @skip_if_no_cuda
 @skip_if_single_gpu
-def test_ddp_parity():
-    temp_file_name = tempfile.mkstemp()[1]
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+@pytest.mark.parametrize("grad_accumulation", [True, False])
+@pytest.mark.parametrize("change_train_graph", [True, False])
+def test_ddp_parity(reduce_buffer_size, grad_accumulation, change_train_graph):
     world_size = torch.cuda.device_count()
     backend = dist.Backend.NCCL
-    mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
+    mp.spawn(
+        run_ddp_parity,
+        args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size, grad_accumulation, change_train_graph),
+        nprocs=world_size,
+        join=True,
+    )


-def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
-    url = "file://" + temp_file_name
-    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name, reduce_buffer_size):
+    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     device = torch.device("cuda")
     torch.cuda.set_device(rank)
     torch.manual_seed(rank)
     np.random.seed(rank)  # Any model works. Add one different buffer per rank
-    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+    BATCHS = 20
+
+    model = _get_mlp()
     model.register_buffer("test_buffer", torch.ones((1)) * rank)
     model.to(device)
     n_half_params = len(list(model.parameters())) // 2
+    optim_settings = {"lr": 1e-3, "momentum": 0.99}

-    sharded_optimizer = OSS(
-        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
-    )
-    sharded_optimizer_2 = OSS(
-        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
-    )
+    sharded_optimizer = OSS(params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, **optim_settings)
+    sharded_optimizer_2 = OSS(params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, **optim_settings)

-    sharded_ddp_model = ShardedDataParallel(module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True)
+    sharded_ddp_model = ShardedDataParallel(
+        module=model,
+        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
+        broadcast_buffers=True,
+        reduce_buffer_size=reduce_buffer_size,
+    )

     ddp_model_single = copy.deepcopy(model)
-    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
-    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
+    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], **optim_settings)
+    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], **optim_settings)
     ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

-    def check_same_model_params():
-        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
-            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
-                assert torch.allclose(
-                    p, ddp_p, atol=1e-3
-                ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
-        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
-            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"
+    check_same_model_params(
+        sharded_ddp_model,
+        ddp_model,
+        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
+    )

-    check_same_model_params()
-    for i in range(20):
+    # The models should stay the same in between the ranks
+    for i in range(BATCHS):
         input_tensor = torch.rand((64, 2)).to(device)

         # Run DDP
@@ -314,6 +326,7 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
         ddp_loss.backward()
         ddp_optimizer.step()
         ddp_optimizer_2.step()
+        torch.cuda.synchronize(device)

         # Run Sharded
         sharded_optimizer.zero_grad()
@@ -322,43 +335,40 @@ def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
         sharded_loss.backward()
         sharded_optimizer.step()
         sharded_optimizer_2.step()
-        check_same_model_params()
+        torch.cuda.synchronize(device)
+
+        check_same_model_params(
+            sharded_ddp_model, ddp_model, f"DDP parity two optim test failing, step {i}, buffers {reduce_buffer_size}",
+        )

     dist.destroy_process_group()


 @skip_if_no_cuda
 @skip_if_single_gpu
-def test_ddp_parity_two_optim():
-    temp_file_name = tempfile.mkstemp()[1]
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+def test_ddp_parity_two_optim(reduce_buffer_size):
     world_size = 2
     backend = dist.Backend.NCCL
-    mp.spawn(run_ddp_parity_two_optim, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
+    mp.spawn(
+        run_ddp_parity_two_optim,
+        args=(world_size, backend, tempfile.mkstemp()[1], reduce_buffer_size),
+        nprocs=world_size,
+        join=True,
+    )


-def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
-    url = "file://" + temp_file_name
-    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
-    if device == torch.device("cuda"):
+def run_test_two_inputs(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
+    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
+    if device == "cuda":
         torch.cuda.set_device(rank)

     torch.manual_seed(rank)
     np.random.seed(rank)

-    class _DoubleInput(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
-
-        def forward(self, x, y):
-            x1 = self.mlp(x)
-            x2 = self.mlp(y)
-            return torch.cat((x1, x2), dim=1)
-
     model = _DoubleInput().to(device)
-
-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer)
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

     # Optim loop
     def closure():
@@ -374,25 +384,32 @@ def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
     dist.destroy_process_group()


-def test_inputs():
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+@pytest.mark.parametrize("backend", ["gloo", "nccl"])
+@pytest.mark.parametrize("device", available_devices)
+def test_inputs(reduce_buffer_size, backend, device):
     # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
     world_size = 2
-    backend = "gloo"
-    temp_file_name = tempfile.mkstemp()[1]
-    device = "cpu"
-    mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    if backend == "nccl" and device == "cpu":
+        pytest.skip("Incompatible combination, or cuda not available")
+        return
+
+    mp.spawn(
+        run_test_two_inputs,
+        args=(world_size, backend, device, tempfile.mkstemp()[1], reduce_buffer_size),
+        nprocs=world_size,
+        join=True,
+    )


 def test_ddp_attributes():
     # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
     # - is multi_device_module
     # - device_type
-
-    url = "file://" + tempfile.mkstemp()[1]
-    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

     model = Sequential(Linear(2, 3), Linear(3, 3))
-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
     ddp_model = ShardedDataParallel(model, optimizer)

     assert hasattr(ddp_model, "is_multi_device_module")
@@ -402,14 +419,12 @@ def test_ddp_attributes():
 def test_random_attributes():
     # Check that ShardedDDP exposes the original module's attributes
-
-    url = "file://" + tempfile.mkstemp()[1]
-    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

     model = Sequential(Linear(2, 3), Linear(3, 3))
     model.banana = "sweet"

-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
     ddp_model = ShardedDataParallel(model, optimizer)

     assert hasattr(ddp_model, "banana")
@@ -418,41 +433,51 @@ def test_random_attributes():
     dist.destroy_process_group()


-def run_test_device_change(rank, world_size, backend, device, temp_file_name):
+def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     # Check that the wrapped module can change devices
+    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
+    torch.cuda.set_device(rank)

-    url = "file://" + temp_file_name
-    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
-
-    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer)
-    ddp_model.to(device)
-
-    inputs = torch.rand((10, 2), device=device)
-    outputs = ddp_model(inputs)  # assert if the module has not been changed properly
-    loss = outputs.norm().backward()
+    model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()  # not device on purpose, test changing it after the fact
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(
+        model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
+    )
+    try:
+        ddp_model.to(device)
+        assert False, "Changing devices should be caught and not supported"
+    except AssertionError:
+        pass

     dist.destroy_process_group()


 @skip_if_no_cuda
 @skip_if_single_gpu
-def test_device_change():
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+def test_device_change(reduce_buffer_size):
     # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
     world_size = 2
-    backend = "gloo"
+    backend = "nccl"
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
-    mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    mp.spawn(
+        run_test_device_change,
+        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
+        nprocs=world_size,
+        join=True,
+    )


-def run_test_training_change(rank, world_size, backend, device, temp_file_name):
-    url = "file://" + temp_file_name
-    group = dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
+    group = dist.init_process_group(
+        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
+    )
+    torch.cuda.set_device(rank)

     model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer, process_group=group)
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

     inputs = torch.rand((10, 2), device=device)
     outputs = ddp_model(inputs)  # assert if the module has not been changed properly
@@ -465,23 +490,30 @@ def run_test_training_change(rank, world_size, backend, device, temp_file_name):
     dist.destroy_process_group()


-def test_training_change():
-    world_size = 8
-    backend = "gloo"
+@skip_if_no_cuda
+@skip_if_single_gpu
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+def test_training_change(reduce_buffer_size):
+    world_size = 2
+    backend = "nccl"
     temp_file_name = tempfile.mkstemp()[1]
-    device = "cpu"
-    mp.spawn(run_test_training_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+    device = "cuda"
+    mp.spawn(
+        run_test_training_change,
+        args=(world_size, backend, device, temp_file_name, reduce_buffer_size),
+        nprocs=world_size,
+        join=True,
+    )


 def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
-    url = "file://" + temp_file_name
-    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

     model = Sequential(Linear(2, 3), torch.nn.BatchNorm1d(3), Linear(3, 3)).to(device)
     model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
     model.to(device)  # in pytorch 1.5 syncBN switches to the default device/cpu

-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
     ddp_model = ShardedDataParallel(model, optimizer)

     assert isinstance(model[1], torch.nn.SyncBatchNorm)
@@ -504,29 +536,17 @@ def test_ddp_sync_batch_norm():
 def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
-    url = "file://" + temp_file_name
-    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

     torch.manual_seed(rank)
     np.random.seed(rank)

-    class _DoubleInput(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
-
-        def forward(self, x, y):
-            x1 = self.mlp(x)
-            x2 = self.mlp(y)
-            return torch.cat((x1, x2), dim=1)
-
     model = _DoubleInput().to(device)

     parameters = list(model.parameters())
-    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
-    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
     ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])

     # Optim loop
@@ -556,22 +576,21 @@ def test_two_optimizers():
 def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
-    INPUT_DIM = 32
+    INPUT_DIM = 16
     BACH_SIZE = 10
     STEPS = 10

     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
-    if device == torch.device("cuda"):
-        torch.cuda.set_device(rank)
+    torch.cuda.set_device(rank)

     torch.manual_seed(rank)
     np.random.seed(rank)
     model = GPT2(
-        embed_dim=512, num_heads=2, num_layers=24, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
+        embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
     ).to(device)
-    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer)
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=0)

     # Optim loop
     def closure():
@@ -600,7 +619,7 @@ def test_gpt2():
     mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


-def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
+def run_test_multiple_groups(rank, world_size, tempfile_name, backend, reduce_buffer_size):
     # Only work with the even ranks, to check that the global_rank indexing is properly used
     dist.init_process_group(init_method="file://" + tempfile_name, backend=backend, rank=rank, world_size=world_size)
@@ -635,17 +654,9 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
             _ = optimizer.step(closure=closure)

             # Check that all the params are the same on all ranks
-            for pg in optimizer.param_groups:
-                for p in pg["params"]:
-                    receptacle = [p.clone() for _ in sub_group_ranks]
-                    dist.all_gather(receptacle, p, group=process_group)
-                    if rank == 0:
-                        for sync_p in receptacle[1:]:
-                            assert torch.all(
-                                torch.eq(receptacle[0], sync_p)
-                            ), "Models differ in between ranks {} - {}".format(
-                                torch.norm(receptacle[0]), torch.norm(sync_p)
-                            )
+            check_same_models_across_ranks(
+                model, process_group, params_should_be_equal=True, check_broadcast_buffers=True
+            )

     if rank in sub_group_ranks:
         # Model not-fitting in the broadcast bucket
@@ -654,24 +665,25 @@ def run_test_multiple_groups(rank, world_size, tempfile_name, backend):
         )

         # With SGD, Momentum is required to get a state to shard
-        optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group)
-        model = ShardedDataParallel(model, optimizer, process_group=process_group)
+        optimizer = OSS(model.parameters(), group=process_group, lr=1e-3, momentum=0.99)
+        model = ShardedDataParallel(
+            model, optimizer, process_group=process_group, reduce_buffer_size=reduce_buffer_size
+        )
         check(optimizer, model)

     dist.destroy_process_group(process_group)


-@skip_if_less_four_gpu
-def test_multiple_groups():
+@skip_if_less_than_four_gpu
+@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
+@pytest.mark.parametrize("backend", ["gloo", "nccl"])
+def test_multiple_groups(reduce_buffer_size, backend):
     world_size = 4
     temp_file_name = tempfile.mkstemp()[1]

-    for backend in ["gloo", "nccl"]:
-        print("Testing backend ", backend)
-        mp.spawn(
-            run_test_multiple_groups, args=(world_size, temp_file_name, backend), nprocs=world_size, join=True,
-        )
-
     mp.spawn(
-        run_test_multiple_groups, args=(world_size, temp_file_name, "gloo"), nprocs=world_size, join=True,
+        run_test_multiple_groups,
+        args=(world_size, temp_file_name, backend, reduce_buffer_size),
+        nprocs=world_size,
+        join=True,
     )