Unverified Commit ad933b34 authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedDataParallel with autoreduce (#157)

* rewrite using autograd and Variable execution queue to make the reduce automatic
* share buckets with OSS to remove duplication
* some speed is likely still on the table, since the speed with bucketing does not match expectations; could be a follow-up
parent 35d4129f
......@@ -124,10 +124,12 @@ run_oss_benchmark: &run_oss_benchmark
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 760 --reference_memory 1120 --reference_loss 0.023
run_oss_gloo: &run_oss_gloo
- run:
name: Run OSS with Gloo
command: |
python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3
- run:
name: Run OSS with Gloo
command: |
python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 2
python benchmarks/oss.py --gloo --optim_type oss_sharded_ddp --epochs 2
run_oss_amp: &run_oss_amp
- run:
......
......@@ -97,19 +97,10 @@ def train(
scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None
if optim_type == OptimType.oss_sharded_ddp:
model = ShardedDDP(
model,
optimizer=OPTIM,
optimizer_params={"lr": 1e-4, "momentum": 0.9},
world_size=args.world_size,
broadcast_buffers=True,
)
optimizer = model.sharded_optimizer
optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
model = ShardedDDP(model, optimizer)
else:
if args.cpu:
device_ids = None
else:
device_ids = [rank]
device_ids = None if args.cpu else [rank]
model = DDP(model, device_ids=device_ids, find_unused_parameters=False) # type: ignore
optimizer = (
OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
......@@ -120,6 +111,7 @@ def train(
# Reset the memory use counter
if not args.cpu:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(rank)
torch.cuda.synchronize(rank)
......@@ -159,9 +151,6 @@ def train(
loss = loss_fn(outputs, data["label"])
loss.backward()
if optim_type == OptimType.oss_sharded_ddp:
model.reduce()
if args.debug and rank == 0 and next(model.parameters()).grad is not None:
logging.debug(
"after BW: param {} -- grad {}".format(
......
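The benchmark hunk above captures the new API: the sharded optimizer is built first with OSS and then handed to ShardedDDP, and the explicit model.reduce() call after the backward pass disappears because reduction is now driven by autograd. A minimal, hedged sketch of that pattern (the single-process gloo group and the toy model are placeholders, not part of the diff):

```python
import tempfile

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

# Single-process group, only so the sketch runs on its own
dist.init_process_group(
    backend="gloo", init_method=f"file://{tempfile.mkstemp()[1]}", rank=0, world_size=1
)

model = torch.nn.Linear(2, 3)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
ddp_model = ShardedDDP(model, optimizer)  # new signature: module + pre-built sharded optimizer

loss = ddp_model(torch.rand(64, 2)).abs().sum()
loss.backward()   # gradients are reduced automatically, no explicit ddp_model.reduce() call
optimizer.step()  # each rank updates its own shard, then broadcasts the new values

dist.destroy_process_group()
```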
......@@ -8,3 +8,4 @@ API Reference
optim/oss
optim/grad_scaler
nn/pipe
nn/sharded_ddp
ShardedDataParallel
====================
.. autoclass:: fairscale.nn.ShardedDataParallel
:members:
:undoc-members:
......@@ -3,7 +3,8 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .data_parallel import ShardedDataParallel
from .moe import MOELayer, Top2Gate
from .pipe import LazyModule, Pipe, PipeRPCWrapper
__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule"]
__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule", "ShardedDataParallel"]
This diff is collapsed.
......@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Dict
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
......@@ -32,15 +32,15 @@ class ShardedGradScaler(TorchGradScaler):
def __init__(self) -> None:
super().__init__()
def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> Optional[float]:
def unscale_(self, optimizer: Optimizer) -> None:
assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"
# Re-use the GradScaler machinery, but make sure that the status is synced between the ranks
# Call the upstream unscale_ method which will only act on this rank's gradients
super().unscale_(optimizer)
# Synchronize the detected inf across the ranks
optimizer_state = self._per_optimizer_states[id(optimizer)]
handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]
# Make sure that the calls are done before moving out
_ = list(map(lambda x: x.wait(), handles))
# Call Torch's GradScaler in turn, states have been synchronized across ranks
return super().step(optimizer, *args, **kwargs)
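The diff above moves the cross-rank synchronization from step() into unscale_(): the upstream GradScaler only unscales this rank's gradients, then the per-device found_inf flags are all-reduced so that every rank agrees on whether the step should be skipped. A hedged usage sketch; the helper name and the import path are assumptions, not part of the diff:

```python
import torch

from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler  # import path assumed


def amp_step(ddp_model, optimizer: OSS, scaler: ShardedGradScaler, data, target, loss_fn):
    # Hypothetical helper: one AMP iteration with a sharded optimizer and scaler
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(ddp_model(data), target)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale this rank's shard, all-reduce the found_inf flags
    scaler.step(optimizer)      # skipped consistently on every rank if any rank saw inf/nan
    scaler.update()
    return loss
```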
......@@ -16,7 +16,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from .utils import broadcast_object, recursive_copy_to_device
from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -73,7 +73,7 @@ class OSS(Optimizer):
super().__init__(params, default)
self.in_super_constructor = False
# Partition information. lazy evaluation, computed if requested
# Partition information. lazy evaluation, computed when requested
self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
......@@ -88,22 +88,26 @@ class OSS(Optimizer):
# - Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
for key, value in local_group.items():
if key != "params":
global_group[key] = value
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
# Current default device is set by the parameters allocated to this rank
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
self._broadcast_buffers: Dict[torch.device, List[torch.Tensor]] = {}
self.buckets: Dict[torch.device, List[Bucket]] = {}
for device, per_device in self.per_device_params.items():
# Allocate one buffer per rank and per device to group the small parameters
self._broadcast_buffers[device] = [
torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device)
self.buckets[device] = [
Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device))
for _ in range(len(per_device))
]
self.should_bucket_param: Dict[torch.Tensor, bool] = {}
self.work_handles: List[Workhandle] = []
self._max_work_handles = -1
self._setup_bucket_strategy()
# Partition helpers
def partition_parameters(self) -> List[List[dict]]:
......@@ -150,9 +154,9 @@ class OSS(Optimizer):
self._per_device_params[device][self.param_to_rank[param]] += [param]
# Sort param_lists by size
for k in self._per_device_params.keys():
for r in self._per_device_params[k]:
r.sort(key=lambda x: x.numel())
for device in self._per_device_params.keys():
for rank_params in self._per_device_params[device]:
rank_params.sort(key=lambda x: x.numel())
return self._per_device_params
......@@ -164,6 +168,9 @@ class OSS(Optimizer):
for param_group in param_groups:
for param in param_group["params"]:
self._param_rank[param] = rank
logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))
return self._param_rank
# NOTE(msb) We add **kwargs in order to support Optimizer sub-classes that accept extra kwargs.
......@@ -181,20 +188,16 @@ class OSS(Optimizer):
self._sync_param_groups()
# Run the optimizer step on this shard only:
self._free_other_grads()
if closure is not None:
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
else:
loss = self.optim.step(**kwargs)
# Depending on the DDP engine used, gradients specific to other ranks may still be loaded
self._free_other_grads()
# Sync all the updated shards in between the ranks
with torch.no_grad():
for (
device,
device_params,
) in self.per_device_params.items(): # all the params on this device (inc all ranks)
self._broadcast_params(self._broadcast_buffers[device], device_params)
self._broadcast_params()
# Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(local_to_global=True)
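With this change, gradients belonging to other ranks are dropped before the local optimizer step instead of after it, and the per-device broadcast loop moves into _broadcast_params(). The closure-based calling pattern that the tests rely on looks roughly like this (a sketch; the helper and its arguments are placeholders):

```python
import torch
from fairscale.optim import OSS


def train_step(ddp_model, optimizer: OSS, data, target, loss_fn):
    # Hypothetical helper showing the closure pattern used in the tests below
    def closure():
        optimizer.zero_grad()
        loss = loss_fn(ddp_model(data), target)
        loss.backward()  # ShardedDDP reduces each gradient to the rank that owns that shard
        return loss

    # Local update of this rank's shard, then async broadcast of the updated shards
    return optimizer.step(closure=closure)
```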
......@@ -489,61 +492,107 @@ class OSS(Optimizer):
for t in p["params"]:
t.grad = None
def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
def _broadcast_params(self) -> None:
"""Helper function to broadcast all the parameters from a given device"""
buffer_size = buffers[0].numel()
bucket_requests = []
direct_requests = []
# Bucket and issue all the async calls
for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
global_src_rank = self.get_global_rank(self.group, src_rank)
# Copy small parameters into per-GPU buffers and then async broadcast
offset = 0
bucket_sent = False
bucket_params = []
# All the params are sorted per rank and per increasing size
for p in params:
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
if not bucket_sent and offset + p.numel() < buffer_size:
end = offset + p.numel()
buffer[offset:end].copy_(p.data.view(-1))
bucket_params.append((p, offset, end))
offset = end
else:
if offset > 0 and not bucket_sent:
bucket_requests.append(
(
dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
src_rank,
bucket_params,
)
# The unroll callback is called when the broadcast is done.
# If this rank is a recipient and the call was bucketed, the results from the broadcast are unrolled
# onto the corresponding parameters.
def get_unroll_callback(src_rank: int, bucket: Bucket) -> Callable:
def unroll() -> None:
if src_rank != self.rank:
for flat in bucket.params:
flat.param.data.copy_(
bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
bucket_sent = True
bucket.reset()
direct_requests.append(
dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
)
return unroll
# Catch a trailing bucket
if not bucket_sent:
bucket_requests.append(
(
dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
src_rank,
bucket_params,
)
)
with torch.no_grad():
for (
device,
device_params,
) in self.per_device_params.items(): # all the params on this device (inc all ranks)
buckets = self.buckets[device]
# Unroll the initial packed small parameters
for work_handle, src_rank, bucket_params in bucket_requests:
work_handle.wait()
if src_rank != self.rank:
for p, offset, end in bucket_params:
p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
# Bucket and issue all the async calls
for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
global_src_rank = self.get_global_rank(self.group, src_rank)
for param in params:
# Bucket broadcast
if self.should_bucket_param[param]:
assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
bucket.max_size,
bucket.current_offset,
param.numel(),
)
if bucket.full():
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
),
callback=get_unroll_callback(src_rank, bucket),
)
)
# Direct
else:
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=param.data, src=global_src_rank, group=self.group, async_op=True
),
callback=None,
)
)
self._consume_work_handles()
def _consume_work_handles(self) -> None:
""" Consume all the futures which are tied to this optimizer's buckets.
We start from the first/oldest ones, since they are the most likely to be ready and non-blocking
"""
for work_handle in self.work_handles:
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
self.work_handles.clear()
def _setup_bucket_strategy(self) -> None:
""" Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
(smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
over the wire.
Generating the partition once and for all allows us to save some time at runtime, and to know when all the
network requests have been issued.
"""
# Unroll all the async work items, just in case
_ = list(map(lambda x: x.wait(), direct_requests))
for device, per_rank_params in self.per_device_params.items():
for dst_rank, params in enumerate(per_rank_params):
offset = 0
bucket_size = self.buckets[device][dst_rank].max_size
for param in params:
if (offset + param.numel()) < bucket_size:
# This parameter is small enough to fit in the remaining size of the bucket
self.should_bucket_param[param] = True
offset += param.numel()
else:
# The parameters are sorted by size, so all the following parameters
# will be too big and can be skipped
self.should_bucket_param[param] = False
# Register the max offset for this buffer
self.buckets[device][dst_rank].max_offset = offset
# Determine the max work handles in flight:
# - all the direct reduce/broadcast + 1 bucket
self._max_work_handles = sum(not value for value in self.should_bucket_param.values()) + 1
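The strategy relies on the per-rank parameter lists being sorted by increasing size: parameters are packed into the flat buffer until the next one no longer fits, and everything after that point is broadcast or reduced directly. A standalone sketch of that decision, with a hypothetical helper and toy sizes:

```python
from typing import Dict, List

import torch


def tag_for_bucketing(params: List[torch.Tensor], bucket_capacity: int) -> Dict[torch.Tensor, bool]:
    # Hypothetical helper mirroring _setup_bucket_strategy for a single rank/device
    should_bucket: Dict[torch.Tensor, bool] = {}
    offset = 0
    for p in params:
        if offset + p.numel() < bucket_capacity:
            # Small enough to be packed into the flat bucket buffer
            should_bucket[p] = True
            offset += p.numel()
        else:
            # Params are sorted by size, so everything from here on is sent directly
            should_bucket[p] = False
    return should_bucket


params = sorted((torch.rand(n) for n in (3, 8, 50, 200)), key=lambda t: t.numel())
tags = tag_for_bucketing(params, bucket_capacity=64)
print([tags[p] for p in params])  # [True, True, True, False]
```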
......@@ -4,13 +4,70 @@
# LICENSE file in the root directory of this source tree.
import io
from typing import Any, Dict
from typing import Any, Callable, Dict, List, Optional
import torch
from torch._six import container_abcs
import torch.distributed as dist
class Workhandle:
def __init__(self, handle: Any, callback: Optional[Callable]) -> None:
self.handle = handle
self.callback = callback
class FlatParam:
def __init__(self, tensor: torch.Tensor, start: int, stop: int) -> None:
self.param = tensor
self.start = start
self.stop = stop
class Bucket:
"""
Helper class to simplify the handling of broadcast or reduce buckets
"""
def __init__(self, buffer: torch.Tensor) -> None:
# The actual flat tensor
self.buffer = buffer
self.max_size = buffer.numel()
# Handles to the params and their position in this tensor, can be useful for a callback
self.params: List[FlatParam] = []
# Current status for this buffer
self.current_offset = 0
self.max_offset = 0
def reset(self) -> None:
""" empty the bucket """
self.current_offset = 0
self.params.clear()
def append(self, tensor: torch.Tensor, use_gradient: bool = False) -> bool:
""" add a tensor to the bucket """
end = self.current_offset + tensor.numel()
if end > self.max_size:
return False
if use_gradient:
assert tensor.grad is not None
data_source = tensor.grad.data if use_gradient else tensor.data # type: ignore # mypy is drunk
self.buffer[self.current_offset : end].copy_(data_source.view(-1))
self.params.append(FlatParam(tensor=tensor, start=self.current_offset, stop=end))
self.current_offset = end
return True
def full(self) -> bool:
""" is the bucket full ? """
return self.current_offset == self.max_offset
# Credits: classy_vision/generic/distributed_util.py
def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
"""
......
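The Bucket and FlatParam helpers added to utils.py above are the pieces that OSS now shares with the new ShardedDataParallel for packing small tensors. A small usage sketch, assuming the import path matches this file (fairscale.optim.utils):

```python
import torch

from fairscale.optim.utils import Bucket  # path assumed from this diff

bucket = Bucket(buffer=torch.zeros(16))
a, b = torch.rand(4), torch.rand(6)

assert bucket.append(a) and bucket.append(b)  # packed at [0:4] and [4:10]
bucket.max_offset = bucket.current_offset     # what _setup_bucket_strategy records up front

assert bucket.full()  # full() means everything planned for this bucket has been appended

# On a receiving rank, the broadcast callback unrolls the flat slices back into the params
for flat in bucket.params:
    flat.param.data.copy_(bucket.buffer[flat.start : flat.stop].view_as(flat.param.data))

bucket.reset()
assert bucket.current_offset == 0 and not bucket.params
```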
......@@ -324,7 +324,7 @@ class Tensor:
def coalesce(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def contiguous(self) -> Tensor: ...
def copy_(self, other: Tensor) -> None: ...
def copy_(self, other: Tensor, non_blocking: Optional[_bool]=False) -> None: ...
def cos(self) -> Tensor: ...
def cos_(self) -> Tensor: ...
def cosh(self) -> Tensor: ...
......
......@@ -12,3 +12,4 @@ class GradScaler(object):
def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]:...
def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ...
def update(self, new_scale: Optional[float]=None): ...
def unscale_(self, optimizer: Optimizer) -> None: ...
......@@ -28,8 +28,10 @@ class ReduceOp:
def get_rank(group: Any = None) -> int: ...
def get_world_size(group: Any = None) -> int: ...
def get_backend(group: Optional[Any] = None) -> Any: ...
def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def is_initialized() -> bool: ...
......
......@@ -8,7 +8,9 @@ Testing OssDdp class.
"""
import tempfile
from typing import List
import numpy as np
import pytest
import torch
import torch.distributed as dist
......@@ -16,18 +18,20 @@ import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
from contextlib import suppress
def test_on_cpu():
run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"))
def test_step_on_cpu():
run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
@skip_if_no_cuda
@skip_if_single_gpu
def test_on_gpu():
def test_step_on_gpu():
run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
......@@ -37,46 +41,78 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3)).to(device)
model.register_buffer("test_buffer", torch.ones((1)) * rank)
def weights_init(m):
if isinstance(m, Linear):
torch.nn.init.constant_(m.weight.data, 1.0)
torch.nn.init.constant_(m.bias.data, 1.0)
model.apply(weights_init)
model.to(device)
ddp = ShardedDataParallel(
module=model,
optimizer=torch.optim.SGD,
optimizer_params={"lr": 0.01, "momentum": 0.99},
world_size=world_size,
broadcast_buffers=True,
)
optimizer = ddp.optimizer
model = ddp.module
# Different input per rank, allows for checking that the gradients have been properly reduced
input_tensor = (torch.ones((64, 2)) * rank).to(device)
output = ddp(input_tensor).abs().sum()
output.backward()
ddp.reduce()
# Check that all the grads have been populated, for the shard
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param.shape == torch.Size([3, 2]):
assert param.grad[0, 0].cpu() == torch.tensor([32.0])
if param.shape == torch.Size([3]):
assert param.grad[0].cpu() == torch.tensor([64.0])
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
for b in model.buffers():
assert b.cpu().item() == 0.0
torch.manual_seed(rank)
np.random.seed(rank)
def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
def check_same_model_params(same_params: bool):
# Check that all the params are the same on all ranks
# This should be true with and without broadcast_buffers, we don't have any real buffer here
receptacle: List[torch.Tensor] = []
if dist.get_backend() != "nccl":
for pg in optimizer.param_groups:
for p in pg["params"]:
# Check the params
receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
dist.gather(p, receptacle, dst=0)
if rank == 0:
for sync_p in receptacle[1:]:
if same_params:
assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
else:
assert not torch.all(
torch.eq(receptacle[0], sync_p)
), "Gradients should not have been synced"
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
if broadcast_buffers:
for b in ddp_model.buffers():
receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
dist.gather(b, receptacle, dst=0)
if rank == 0:
for sync_b in receptacle[1:]:
if same_params:
assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
else:
assert not torch.all(
torch.eq(receptacle[0], sync_b)
), "Gradients should not have been synced"
assert b.cpu().item() == 0.0
# The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
check_same_model_params(same_params=True)
# Optim loop
def closure():
optimizer.zero_grad()
with ddp_model.no_sync() if grad_accumulation else suppress():
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
# when running on cpu/gloo the "nodes" are not really different
same_params = device == torch.device("cpu") or grad_accumulation
check_same_model_params(same_params=same_params)
check(broadcast_buffers=False)
check(broadcast_buffers=True)
check(broadcast_buffers=False, grad_accumulation=True)
check(broadcast_buffers=True, grad_accumulation=True)
dist.destroy_process_group()
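The consistency check above gathers every rank's copy of each parameter (and buffer) onto rank 0 and compares them, which is why it only runs when the backend is not nccl. A standalone sketch of that check (hypothetical helper, not part of the test file):

```python
import torch
import torch.distributed as dist


def assert_same_across_ranks(tensor: torch.Tensor) -> None:
    # Rank 0 gathers every rank's copy of the tensor and compares them to its own
    rank, world_size = dist.get_rank(), dist.get_world_size()
    receptacle = [tensor.clone() for _ in range(world_size)] if rank == 0 else None
    dist.gather(tensor, receptacle, dst=0)
    if rank == 0:
        for other in receptacle[1:]:
            assert torch.all(torch.eq(receptacle[0], other)), "tensor differs between ranks"
```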
......@@ -85,33 +121,116 @@ def run_test(backend, device, world_size=2):
mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_eval_mode(_unused):
""" Testing eval mode make sure this is no asserts. """
dist.init_process_group(
init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
)
model = Sequential(Linear(2, 3), Linear(3, 4))
optimizer_params = {"lr": 0.1, "momentum": 0.99}
ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
optimizer = ddp.optimizer
ddp.eval()
for _ in range(5):
input_tensor = torch.rand((64, 2))
output = ddp(input_tensor)
ddp.train()
try:
for _ in range(5):
input_tensor = torch.rand((64, 2))
output = ddp(input_tensor)
except RuntimeError:
pass
else:
assert False, "Multiple forward passes on training mode should not pass"
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
model = _DoubleInput().to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
# Optim loop
def closure():
optimizer.zero_grad()
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
dist.destroy_process_group()
def test_eval_mode():
mp.spawn(run_eval_mode, args=(), join=True)
def test_inputs():
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def test_ddp_attributes():
# Check that ShardedDDP exposes the same attributes as Pytorch's DDP
# - is multi_device_module
# - device_type
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert hasattr(ddp_model, "is_multi_device_module")
assert hasattr(ddp_model, "device_type")
dist.destroy_process_group()
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
model = _DoubleInput().to(device)
parameters = list(model.parameters())
optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])
# Optim loop
def closure():
optimizer_1.zero_grad()
optimizer_2.zero_grad()
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer_1.step(closure=closure)
_ = optimizer_2.step()
dist.destroy_process_group()
def test_two_optimizers():
# Check that the ShardedDDP wrapper accepts several optimizers
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)