Unverified commit ad933b34, authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedDataParallel with autoreduce (#157)

* rewrite using autograd and the Variable execution queue so that the gradient reduce happens automatically
* share the buckets with OSS to remove duplication
* some speed is likely still on the table, since the speedup from bucketing does not match expectations; this could be a follow-up
parent 35d4129f
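The mechanism behind the commit message is that gradient reduction is triggered from within the backward pass itself instead of by an explicit model.reduce() call after loss.backward(). The sketch below illustrates the general pattern with plain per-parameter gradient hooks; it is a simplified stand-in, not the actual ShardedDataParallel implementation (which also buckets small gradients and hooks into the Variable execution queue). The helper name and the param_to_rank mapping are hypothetical.

import torch
import torch.distributed as dist


def attach_auto_reduce_hooks(module: torch.nn.Module, param_to_rank: dict, world_size: int) -> list:
    """Register per-parameter gradient hooks so that each gradient is reduced to the rank
    owning that shard as soon as autograd produces it (no explicit reduce() call needed)."""
    handles = []
    for param in filter(lambda p: p.requires_grad, module.parameters()):
        dst_rank = param_to_rank[param]  # rank responsible for updating this parameter

        def make_hook(dst: int):
            def hook(grad: torch.Tensor) -> torch.Tensor:
                grad = grad / world_size  # pre-divide so that the SUM reduce yields an average
                # Blocking reduce for simplicity; a real implementation issues async calls
                # and buckets small gradients to amortize the communication latency.
                dist.reduce(grad, dst=dst, op=dist.ReduceOp.SUM)
                return grad  # only the owning rank ends up with the fully reduced gradient

            return hook

        handles.append(param.register_hook(make_hook(dst_rank)))
    return handles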
@@ -124,10 +124,12 @@ run_oss_benchmark: &run_oss_benchmark
             python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 760 --reference_memory 1120 --reference_loss 0.023

 run_oss_gloo: &run_oss_gloo
   - run:
       name: Run OSS with Gloo
       command: |
-        python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3
+        python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 2
+        python benchmarks/oss.py --gloo --optim_type oss_sharded_ddp --epochs 2

 run_oss_amp: &run_oss_amp
   - run:
...
@@ -97,19 +97,10 @@ def train(
     scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

     if optim_type == OptimType.oss_sharded_ddp:
-        model = ShardedDDP(
-            model,
-            optimizer=OPTIM,
-            optimizer_params={"lr": 1e-4, "momentum": 0.9},
-            world_size=args.world_size,
-            broadcast_buffers=True,
-        )
-        optimizer = model.sharded_optimizer
+        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
+        model = ShardedDDP(model, optimizer)
     else:
-        if args.cpu:
-            device_ids = None
-        else:
-            device_ids = [rank]
+        device_ids = None if args.cpu else [rank]
         model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
         optimizer = (
             OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
@@ -120,6 +111,7 @@ def train(
     # Reset the memory use counter
     if not args.cpu:
+        torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(rank)
         torch.cuda.synchronize(rank)
@@ -159,9 +151,6 @@ def train(
             loss = loss_fn(outputs, data["label"])
             loss.backward()

-            if optim_type == OptimType.oss_sharded_ddp:
-                model.reduce()
-
             if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                 logging.debug(
                     "after BW: param {} -- grad {}".format(
...
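The hunk above captures the new user-facing API: build the OSS optimizer first, hand it to ShardedDataParallel, and drop the explicit model.reduce() call since reduction now happens during backward. A minimal usage sketch along those lines (the data loader, loss, and process-group setup are placeholders, not part of this diff):

import torch
import torch.nn.functional as F
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

# Assumes torch.distributed.init_process_group(...) has already been called on every rank
model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
model = ShardedDDP(model, optimizer)  # gradients are reduced to their owning rank during backward

for batch, target in loader:  # `loader` is a placeholder for any per-rank data loader
    optimizer.zero_grad()
    loss = F.mse_loss(model(batch), target)
    loss.backward()   # no model.reduce() needed anymore
    optimizer.step()  # each rank updates its shard, then broadcasts the new values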
@@ -8,3 +8,4 @@ API Reference
     optim/oss
     optim/grad_scaler
     nn/pipe
+    nn/sharded_ddp
ShardedDataParallel
====================
.. autoclass:: fairscale.nn.ShardedDataParallel
:members:
:undoc-members:
@@ -3,7 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from .data_parallel import ShardedDataParallel
 from .moe import MOELayer, Top2Gate
 from .pipe import LazyModule, Pipe, PipeRPCWrapper

-__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule"]
+__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule", "ShardedDataParallel"]
This diff is collapsed.
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Dict

 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
@@ -32,15 +32,15 @@ class ShardedGradScaler(TorchGradScaler):
     def __init__(self) -> None:
         super().__init__()

-    def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> Optional[float]:
+    def unscale_(self, optimizer: Optimizer) -> None:
         assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"

-        # Re-use the GradScaler machinery, but make sure that the status is sync'ed in between the ranks
+        # Call the upstream unscale_ method, which will only act on this rank's gradients
+        super().unscale_(optimizer)
+
+        # Synchronize the detected inf across the ranks
         optimizer_state = self._per_optimizer_states[id(optimizer)]

         handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]

         # Make sure that the calls are done before moving out
         _ = list(map(lambda x: x.wait(), handles))
-
-        # Call Torch's GradScaler in turn, states have been synchronized across ranks
-        return super().step(optimizer, *args, **kwargs)
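With this change the rank synchronization moves from step() into unscale_(), so ShardedGradScaler slots into the standard torch.cuda.amp loop: the inherited GradScaler.step() calls the overridden unscale_(), which all-reduces the found_inf flags so that every rank agrees on whether to skip the update. A hedged usage sketch (the import path is assumed from this diff; model and optimizer setup is assumed to exist elsewhere):

import torch
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler  # import path assumed from this diff


def amp_training_step(ddp_model, optimizer: OSS, scaler: ShardedGradScaler, batch, target):
    """One AMP step with a sharded optimizer: unscale_ syncs the inf/nan detection across ranks."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(ddp_model(batch), target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # triggers the overridden unscale_, then steps only if no rank saw an inf/nan
    scaler.update()
    return loss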
@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

-from .utils import broadcast_object, recursive_copy_to_device
+from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device

 __all__ = ["OSS"]
@@ -73,7 +73,7 @@ class OSS(Optimizer):
         super().__init__(params, default)
         self.in_super_constructor = False

-        # Partition information. lazy evaluation, computed if requested
+        # Partition information. lazy evaluation, computed when requested
         self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
         self._param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
@@ -88,22 +88,26 @@ class OSS(Optimizer):
         #  - Sync local and global param_groups keys
         for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
-            for k, v in local_group.items():
-                if k != "params":
-                    global_group[k] = v
+            for key, value in local_group.items():
+                if key != "params":
+                    global_group[key] = value

         #  Optional consolidated optimizer state
         self._all_states: List[Dict[str, Any]] = []

         # Current default device is set by the parameters allocated to this rank
         self._device = self.partition_parameters()[self.rank][0]["params"][0].device

-        self._broadcast_buffers: Dict[torch.device, List[torch.Tensor]] = {}
+        self.buckets: Dict[torch.device, List[Bucket]] = {}
         for device, per_device in self.per_device_params.items():
             # Allocate one buffer per rank and per device to group the small parameters
-            self._broadcast_buffers[device] = [
-                torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device)
+            self.buckets[device] = [
+                Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device))
                 for _ in range(len(per_device))
             ]

+        self.should_bucket_param: Dict[torch.Tensor, bool] = {}
+        self.work_handles: List[Workhandle] = []
+        self._max_work_handles = -1
+        self._setup_bucket_strategy()
+
     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
@@ -150,9 +154,9 @@ class OSS(Optimizer):
                 self._per_device_params[device][self.param_to_rank[param]] += [param]

             # Sort param_lists by size
-            for k in self._per_device_params.keys():
-                for r in self._per_device_params[k]:
-                    r.sort(key=lambda x: x.numel())
+            for device in self._per_device_params.keys():
+                for rank_params in self._per_device_params[device]:
+                    rank_params.sort(key=lambda x: x.numel())

         return self._per_device_params
@@ -164,6 +168,9 @@ class OSS(Optimizer):
             for param_group in param_groups:
                 for param in param_group["params"]:
                     self._param_rank[param] = rank
+
+            logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))
+
         return self._param_rank

     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
@@ -181,20 +188,16 @@
         self._sync_param_groups()

         # Run the optimizer step on this shard only:
+        self._free_other_grads()
+
         if closure is not None:
             loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
         else:
             loss = self.optim.step(**kwargs)

-        # Depending on the DDP engine used, gradients specific to other ranks may still be loaded
-        self._free_other_grads()
-
         # Sync all the updated shards in between the ranks
-        with torch.no_grad():
-            for (
-                device,
-                device_params,
-            ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-                self._broadcast_params(self._broadcast_buffers[device], device_params)
+        self._broadcast_params()

         # Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
         self._sync_param_groups(local_to_global=True)
@@ -489,61 +492,107 @@
         for t in p["params"]:
             t.grad = None

-    def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
+    def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
-        buffer_size = buffers[0].numel()
-        bucket_requests = []
-        direct_requests = []
-
-        # Bucket and issue all the async calls
-        for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
-            global_src_rank = self.get_global_rank(self.group, src_rank)
-
-            # Copy small parameters into per-GPU buffers and then async broadcast
-            offset = 0
-            bucket_sent = False
-            bucket_params = []
-
-            # All the params are sorted per rank and per increasing size
-            for p in params:
-                # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
-                if not bucket_sent and offset + p.numel() < buffer_size:
-                    end = offset + p.numel()
-                    buffer[offset:end].copy_(p.data.view(-1))
-                    bucket_params.append((p, offset, end))
-                    offset = end
-                else:
-                    if offset > 0 and not bucket_sent:
-                        bucket_requests.append(
-                            (
-                                dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
-                                src_rank,
-                                bucket_params,
-                            )
-                        )
-                        bucket_sent = True
-
-                    direct_requests.append(
-                        dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
-                    )
-
-            # Catch a trailing bucket
-            if not bucket_sent:
-                bucket_requests.append(
-                    (
-                        dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
-                        src_rank,
-                        bucket_params,
-                    )
-                )
-
-        # Unroll the initial packed small parameters
-        for work_handle, src_rank, bucket_params in bucket_requests:
-            work_handle.wait()
-
-            if src_rank != self.rank:
-                for p, offset, end in bucket_params:
-                    p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
-
-        # Unroll all the async work items, just in case
-        _ = list(map(lambda x: x.wait(), direct_requests))
+
+        # The unroll callback is called when the broadcast is done.
+        # If this rank is a recipient and the call was bucketed, the results from the broadcast are unrolled
+        # onto the corresponding parameters.
+        def get_unroll_callback(src_rank: int, bucket: Bucket) -> Callable:
+            def unroll() -> None:
+                if src_rank != self.rank:
+                    for flat in bucket.params:
+                        flat.param.data.copy_(
+                            bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
+                        )
+                bucket.reset()
+
+            return unroll
+
+        with torch.no_grad():
+            for (
+                device,
+                device_params,
+            ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
+                buckets = self.buckets[device]
+
+                # Bucket and issue all the async calls
+                for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
+                    global_src_rank = self.get_global_rank(self.group, src_rank)
+
+                    for param in params:
+                        # Bucket broadcast
+                        if self.should_bucket_param[param]:
+                            assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
+                                bucket.max_size,
+                                bucket.current_offset,
+                                param.numel(),
+                            )
+
+                            if bucket.full():
+                                self.work_handles.append(
+                                    Workhandle(
+                                        handle=dist.broadcast(
+                                            tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
+                                        ),
+                                        callback=get_unroll_callback(src_rank, bucket),
+                                    )
+                                )
+                        # Direct
+                        else:
+                            self.work_handles.append(
+                                Workhandle(
+                                    handle=dist.broadcast(
+                                        tensor=param.data, src=global_src_rank, group=self.group, async_op=True
+                                    ),
+                                    callback=None,
+                                )
+                            )
+
+        self._consume_work_handles()
+
+    def _consume_work_handles(self) -> None:
+        """Consume all the futures which are tied to this optimizer's buckets.
+        We start from the first/older ones, since they are the most likely to be ready and non-blocking.
+        """
+        for work_handle in self.work_handles:
+            work_handle.handle.wait()
+            if work_handle.callback is not None:
+                work_handle.callback()
+
+        self.work_handles.clear()
+
+    def _setup_bucket_strategy(self) -> None:
+        """Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
+        (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
+        over the wire.

+        Generating the partition once and for all allows us to save some time at runtime, and to know when all the
+        network requests have been issued.
+        """
+        for device, per_rank_params in self.per_device_params.items():
+            for dst_rank, params in enumerate(per_rank_params):
+                offset = 0
+                bucket_size = self.buckets[device][dst_rank].max_size
+
+                for param in params:
+                    if (offset + param.numel()) < bucket_size:
+                        # This parameter is small enough to fit in the remaining size of the bucket
+                        self.should_bucket_param[param] = True
+                        offset += param.numel()
+                    else:
+                        # The parameters are sorted by size, so all the following parameters
+                        # will be too big and can be skipped
+                        self.should_bucket_param[param] = False
+
+                # Register the max offset for this buffer
+                self.buckets[device][dst_rank].max_offset = offset
+
+        # Determine the max work handles in flight:
+        # - all the direct reduce/broadcast + 1 bucket
+        self._max_work_handles = sum(not value for value in self.should_bucket_param.values()) + 1
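The bucketing strategy above depends on the per-rank parameter lists being sorted by increasing size: the smallest tensors are tagged for the flat bucket until it would overflow, and everything after that is broadcast directly. A standalone sketch of that split decision (illustrative only, not the fairscale code):

import torch


def split_bucket_vs_direct(params, bucket_numel: int):
    """Given parameters sorted by increasing size, return those that fit into a flat
    bucket of `bucket_numel` elements and those that must be sent individually."""
    params = sorted(params, key=lambda p: p.numel())
    bucketed, direct, offset = [], [], 0
    for p in params:
        if offset + p.numel() < bucket_numel:
            bucketed.append(p)
            offset += p.numel()
        else:
            # Sorted by size, so every following parameter is at least as large: send it directly.
            direct.append(p)
    return bucketed, direct


# Example: the small weights and biases end up in the bucket, the big tensor goes direct
params = [torch.zeros(8), torch.zeros(16), torch.zeros(64), torch.zeros(4096)]
bucketed, direct = split_bucket_vs_direct(params, bucket_numel=128)
assert sum(p.numel() for p in bucketed) == 88 and len(direct) == 1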
@@ -4,13 +4,70 @@
 # LICENSE file in the root directory of this source tree.

 import io
-from typing import Any, Dict
+from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch._six import container_abcs
 import torch.distributed as dist


+class Workhandle:
+    def __init__(self, handle: Any, callback: Optional[Callable]) -> None:
+        self.handle = handle
+        self.callback = callback
+
+
+class FlatParam:
+    def __init__(self, tensor: torch.Tensor, start: int, stop: int) -> None:
+        self.param = tensor
+        self.start = start
+        self.stop = stop
+
+
+class Bucket:
+    """
+    Helper class to simplify the handling of broadcast or reduce buckets
+    """
+
+    def __init__(self, buffer: torch.Tensor) -> None:
+        # The actual flat tensor
+        self.buffer = buffer
+        self.max_size = buffer.numel()
+
+        # Handles to the params and their position in this tensor, can be useful for a callback
+        self.params: List[FlatParam] = []
+
+        # Current status for this buffer
+        self.current_offset = 0
+        self.max_offset = 0
+
+    def reset(self) -> None:
+        """Empty the bucket"""
+        self.current_offset = 0
+        self.params.clear()
+
+    def append(self, tensor: torch.Tensor, use_gradient: bool = False) -> bool:
+        """Add a tensor to the bucket"""
+        end = self.current_offset + tensor.numel()
+        if end > self.max_size:
+            return False
+
+        if use_gradient:
+            assert tensor.grad is not None
+
+        data_source = tensor.grad.data if use_gradient else tensor.data  # type: ignore # mypy is drunk
+        self.buffer[self.current_offset : end].copy_(data_source.view(-1))
+        self.params.append(FlatParam(tensor=tensor, start=self.current_offset, stop=end))
+        self.current_offset = end
+
+        return True
+
+    def full(self) -> bool:
+        """Is the bucket full?"""
+        return self.current_offset == self.max_offset
+
+
 # Credits: classy_vision/generic/distributed_util.py
 def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
     """
...
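A short illustration of how the Bucket helper defined above is meant to be used: tensors are appended into the flat buffer, the buffer is sent in a single collective once the planned offset is reached, and a callback unrolls it back onto the parameters on the receiving side. The values below are arbitrary, max_offset is set by hand (OSS normally computes it once in _setup_bucket_strategy), and the import path is assumed from this diff.

import torch
from fairscale.optim.utils import Bucket  # import path assumed from this diff

bucket = Bucket(buffer=torch.zeros(32))          # one flat 32-element buffer
a, b = torch.full((8,), 1.0), torch.full((16,), 2.0)

assert bucket.append(a) and bucket.append(b)     # both tensors fit and are copied into the buffer
bucket.max_offset = bucket.current_offset        # normally precomputed by _setup_bucket_strategy
assert bucket.full()                             # the bucket now holds everything planned for it

# After a (hypothetical) dist.broadcast of bucket.buffer, a receiving rank unrolls it:
for flat in bucket.params:
    flat.param.data.copy_(bucket.buffer[flat.start : flat.stop].view_as(flat.param.data))
bucket.reset()                                   # ready to be reused for the next round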
@@ -324,7 +324,7 @@ class Tensor:
     def coalesce(self) -> Tensor: ...
     def conj(self) -> Tensor: ...
     def contiguous(self) -> Tensor: ...
-    def copy_(self, other: Tensor) -> None: ...
+    def copy_(self, other: Tensor, non_blocking: Optional[_bool] = False) -> None: ...
     def cos(self) -> Tensor: ...
     def cos_(self) -> Tensor: ...
     def cosh(self) -> Tensor: ...
...
@@ -12,3 +12,4 @@ class GradScaler(object):
     def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]: ...
     def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ...
     def update(self, new_scale: Optional[float] = None): ...
+    def unscale_(self, optimizer: Optimizer) -> None: ...
@@ -28,8 +28,10 @@ class ReduceOp:
 def get_rank(group: Any = None) -> int: ...
 def get_world_size(group: Any = None) -> int: ...
+def get_backend(group: Optional[Any] = None) -> Any: ...
 def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
+def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group: Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
+def reduce(tensor: Tensor, dst: Any, op: Optional[Any] = ReduceOp.SUM, group: Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
 def is_initialized() -> bool: ...
...
@@ -8,7 +8,9 @@ Testing OssDdp class.
 """

 import tempfile
+from typing import List

+import numpy as np
 import pytest
 import torch
 import torch.distributed as dist
@@ -16,18 +18,20 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential

 from fairscale.nn.data_parallel import ShardedDataParallel
+from fairscale.optim import OSS

 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")

+from contextlib import suppress
+

-def test_on_cpu():
-    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"))
+def test_step_on_cpu():
+    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)


 @skip_if_no_cuda
 @skip_if_single_gpu
-def test_on_gpu():
+def test_step_on_gpu():
     run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
@@ -37,46 +41,78 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

-    # Any model works. Add one different buffer per rank
-    model = Sequential(Linear(2, 3)).to(device)
-    model.register_buffer("test_buffer", torch.ones((1)) * rank)
-
-    def weights_init(m):
-        if isinstance(m, Linear):
-            torch.nn.init.constant_(m.weight.data, 1.0)
-            torch.nn.init.constant_(m.bias.data, 1.0)
-
-    model.apply(weights_init)
-    model.to(device)
-
-    ddp = ShardedDataParallel(
-        module=model,
-        optimizer=torch.optim.SGD,
-        optimizer_params={"lr": 0.01, "momentum": 0.99},
-        world_size=world_size,
-        broadcast_buffers=True,
-    )
-    optimizer = ddp.optimizer
-    model = ddp.module
-
-    # Different input per rank, allows for checking that the gradients have been properly reduced
-    input_tensor = (torch.ones((64, 2)) * rank).to(device)
-    output = ddp(input_tensor).abs().sum()
-    output.backward()
-    ddp.reduce()
-
-    # Check that all the grads have been populated, for the shard
-    for pg in optimizer.optim.param_groups:
-        for param in pg["params"]:
-            if param.shape == torch.Size([3, 2]):
-                assert param.grad[0, 0].cpu() == torch.tensor([32.0])
-            if param.shape == torch.Size([3]):
-                assert param.grad[0].cpu() == torch.tensor([64.0])
-
-    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
-    for b in model.buffers():
-        assert b.cpu().item() == 0.0
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
+        # Any model works. Add one different buffer per rank
+        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+        model.register_buffer("test_buffer", torch.ones((1)) * rank)
+        model.to(device)
+
+        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+        ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
+
+        def check_same_model_params(same_params: bool):
+            # Check that all the params are the same on all ranks
+            # This should be true with and without broadcast_buffers, we don't have any real buffer here
+            receptacle: List[torch.Tensor] = []
+
+            if dist.get_backend() != "nccl":
+                for pg in optimizer.param_groups:
+                    for p in pg["params"]:
+                        # Check the params
+                        receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
+                        dist.gather(p, receptacle, dst=0)
+                        if rank == 0:
+                            for sync_p in receptacle[1:]:
+                                if same_params:
+                                    assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
+                                else:
+                                    assert not torch.all(
+                                        torch.eq(receptacle[0], sync_p)
+                                    ), "Gradients should not have been synced"
+
+                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
+                if broadcast_buffers:
+                    for b in ddp_model.buffers():
+                        receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
+                        dist.gather(b, receptacle, dst=0)
+                        if rank == 0:
+                            for sync_b in receptacle[1:]:
+                                if same_params:
+                                    assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
+                                else:
+                                    assert not torch.all(
+                                        torch.eq(receptacle[0], sync_b)
+                                    ), "Gradients should not have been synced"
+
+                        assert b.cpu().item() == 0.0
+
+        # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
+        check_same_model_params(same_params=True)
+
+        # Optim loop
+        def closure():
+            optimizer.zero_grad()
+            with ddp_model.no_sync() if grad_accumulation else suppress():
+                input_tensor = torch.rand((64, 2)).to(device)
+                loss = ddp_model(input_tensor).abs().sum()
+                loss.backward()
+            return loss
+
+        # The models should stay the same in between the ranks
+        for i in range(5):
+            _ = optimizer.step(closure=closure)
+
+        # when running on cpu/gloo the "nodes" are not really different
+        same_params = device == torch.device("cpu") or grad_accumulation
+        check_same_model_params(same_params=same_params)
+
+    check(broadcast_buffers=False)
+    check(broadcast_buffers=True)
+    check(broadcast_buffers=False, grad_accumulation=True)
+    check(broadcast_buffers=True, grad_accumulation=True)

     dist.destroy_process_group()
@@ -85,33 +121,116 @@ def run_test(backend, device, world_size=2):
     mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


-def run_eval_mode(_unused):
-    """ Testing eval mode make sure this is no asserts. """
-    dist.init_process_group(
-        init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
-    )
-    model = Sequential(Linear(2, 3), Linear(3, 4))
-    optimizer_params = {"lr": 0.1, "momentum": 0.99}
-    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
-    optimizer = ddp.optimizer
-
-    ddp.eval()
-    for _ in range(5):
-        input_tensor = torch.rand((64, 2))
-        output = ddp(input_tensor)
-
-    ddp.train()
-    try:
-        for _ in range(5):
-            input_tensor = torch.rand((64, 2))
-            output = ddp(input_tensor)
-    except RuntimeError:
-        pass
-    else:
-        assert False, "Multiple forward passes on training mode should not pass"
+def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    if device == torch.device("cuda"):
+        torch.cuda.set_device(rank)
+
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    class _DoubleInput(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+
+        def forward(self, x, y):
+            x1 = self.mlp(x)
+            x2 = self.mlp(y)
+            return torch.cat((x1, x2), dim=1)
+
+    model = _DoubleInput().to(device)
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    # Optim loop
+    def closure():
+        optimizer.zero_grad()
+        input_tensor = torch.rand((64, 2)).to(device)
+        loss = ddp_model(input_tensor, input_tensor).abs().sum()
+        loss.backward()
+        return loss
+
+    # The models should stay the same in between the ranks
+    for i in range(5):
+        _ = optimizer.step(closure=closure)

     dist.destroy_process_group()
-def test_eval_mode():
-    mp.spawn(run_eval_mode, args=(), join=True)
+def test_inputs():
+    # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
+    world_size = 2
+    backend = "gloo"
+    temp_file_name = tempfile.mkstemp()[1]
+    device = "cpu"
+    mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+
+
+def test_ddp_attributes():
+    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
+    # - is_multi_device_module
+    # - device_type
+    url = "file://" + tempfile.mkstemp()[1]
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
+
+    model = Sequential(Linear(2, 3), Linear(3, 3))
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    assert hasattr(ddp_model, "is_multi_device_module")
+    assert hasattr(ddp_model, "device_type")
+    dist.destroy_process_group()
+
+
+def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    if device == torch.device("cuda"):
+        torch.cuda.set_device(rank)
+
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    class _DoubleInput(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+
+        def forward(self, x, y):
+            x1 = self.mlp(x)
+            x2 = self.mlp(y)
+            return torch.cat((x1, x2), dim=1)
+
+    model = _DoubleInput().to(device)
+
+    parameters = list(model.parameters())
+    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])
+
+    # Optim loop
+    def closure():
+        optimizer_1.zero_grad()
+        optimizer_2.zero_grad()
+        input_tensor = torch.rand((64, 2)).to(device)
+        loss = ddp_model(input_tensor, input_tensor).abs().sum()
+        loss.backward()
+        return loss
+
+    # The models should stay the same in between the ranks
+    for i in range(5):
+        _ = optimizer_1.step(closure=closure)
+        _ = optimizer_2.step(closure=closure)
+
+    dist.destroy_process_group()
+
+
+def test_two_optimizers():
+    # Check that the ShardedDDP wrapper handles multiple sharded optimizers
+    world_size = 2
+    backend = "gloo"
+    temp_file_name = tempfile.mkstemp()[1]
+    device = "cpu"
+    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
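The grad_accumulation cases in these tests rely on the no_sync() context manager exposed by ShardedDataParallel, mirroring PyTorch's DDP: while it is active, backward accumulates gradients locally and skips the reduction. A hedged sketch of the usual accumulation pattern built on that (the model, optimizer and micro-batch iterable are assumed to exist elsewhere):

import contextlib
import torch.nn.functional as F


def accumulate_then_step(ddp_model, optimizer, micro_batches, accumulation_steps: int = 4):
    """Accumulate gradients locally and only reduce/step every `accumulation_steps` micro-batches."""
    optimizer.zero_grad()
    for i, (batch, target) in enumerate(micro_batches):
        is_last = (i + 1) % accumulation_steps == 0
        # Skip the gradient reduction for every micro-batch except the last of each accumulation window
        context = contextlib.suppress() if is_last else ddp_model.no_sync()
        with context:
            F.mse_loss(ddp_model(batch), target).backward()
        if is_last:
            optimizer.step()       # gradients were reduced during the last backward
            optimizer.zero_grad()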