"examples/vscode:/vscode.git/clone" did not exist on "d6aeaa74b717846fbbafa6975f41a09b0d666cae"
Unverified commit ad933b34, authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedDataParallel with autoreduce (#157)

* rewrite using autograd hooks and the Variable execution queue so that the gradient reduce happens automatically during the backward pass
* share the buckets with OSS to remove duplication
* some speed is probably still on the table, since the bucketing gains do not yet match expectations; this could be a follow-up
parent 35d4129f
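In practice, the wrapping order flips with this change: build the OSS optimizer first, hand it to ShardedDataParallel, and drop the explicit reduce() call. A minimal single-process sketch, assuming only the OSS and ShardedDataParallel signatures visible in this diff (the model, data and hyper-parameters below are made up):

import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process group, only so that the sketch runs on one machine
dist.init_process_group(backend="gloo", init_method=f"file://{tempfile.mkstemp()[1]}", rank=0, world_size=1)

model = Sequential(Linear(2, 3), Linear(3, 1))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
model = ShardedDataParallel(model, optimizer)  # gradients are now reduced automatically during backward

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.rand(8, 2)).abs().sum()  # dummy data and loss
    loss.backward()  # no explicit model.reduce() call needed anymore
    optimizer.step()

dist.destroy_process_group()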
@@ -124,10 +124,12 @@ run_oss_benchmark: &run_oss_benchmark
       python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 760 --reference_memory 1120 --reference_loss 0.023

 run_oss_gloo: &run_oss_gloo
   - run:
       name: Run OSS with Gloo
       command: |
-        python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3
+        python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 2
+        python benchmarks/oss.py --gloo --optim_type oss_sharded_ddp --epochs 2

 run_oss_amp: &run_oss_amp
   - run:
@@ -97,19 +97,10 @@ def train(
     scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

     if optim_type == OptimType.oss_sharded_ddp:
-        model = ShardedDDP(
-            model,
-            optimizer=OPTIM,
-            optimizer_params={"lr": 1e-4, "momentum": 0.9},
-            world_size=args.world_size,
-            broadcast_buffers=True,
-        )
-        optimizer = model.sharded_optimizer
+        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
+        model = ShardedDDP(model, optimizer)
     else:
-        if args.cpu:
-            device_ids = None
-        else:
-            device_ids = [rank]
+        device_ids = None if args.cpu else [rank]
         model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
         optimizer = (
             OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
@@ -120,6 +111,7 @@ def train(
     # Reset the memory use counter
     if not args.cpu:
+        torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats(rank)
         torch.cuda.synchronize(rank)
@@ -159,9 +151,6 @@ def train(
             loss = loss_fn(outputs, data["label"])
             loss.backward()

-            if optim_type == OptimType.oss_sharded_ddp:
-                model.reduce()
-
             if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                 logging.debug(
                     "after BW: param {} -- grad {}".format(
@@ -8,3 +8,4 @@ API Reference
     optim/oss
     optim/grad_scaler
     nn/pipe
+    nn/sharded_ddp

ShardedDataParallel
===================

.. autoclass:: fairscale.nn.ShardedDataParallel
    :members:
    :undoc-members:
@@ -3,7 +3,8 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from .data_parallel import ShardedDataParallel
 from .moe import MOELayer, Top2Gate
 from .pipe import LazyModule, Pipe, PipeRPCWrapper

-__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule"]
+__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule", "ShardedDataParallel"]
@@ -4,234 +4,283 @@
 # LICENSE file in the root directory of this source tree.

 """
-A distributed data parallel class that works with OSS optimizer.
-
-Adopted from LegacyDistributedDataParallel module from fairseq.
+A nn.Module wrapper to go with a Sharded Optimizer in order to handle targeted gradient
+reduction automatically.
 """

-from contextlib import contextmanager
-import copy
-from typing import Any, Dict, Generator, List, Type, cast
+import contextlib
+from itertools import chain
+import logging
+from typing import Any, Callable, Generator, List, Tuple, Union

 import torch
-from torch import Tensor, nn
+from torch import nn
+from torch.autograd import Variable
 import torch.distributed as dist
 from torch.nn import Parameter

 from fairscale.optim import OSS
+from fairscale.optim.utils import Workhandle


 class ShardedDataParallel(nn.Module):
-    """Implements distributed data parallel training with optimizer state sharding.
-
-    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
-    This version uses a c10d process group for communication and optionally
-    broadcast buffers.
-
-    Args:
-        module (~torch.nn.Module): module to be parallelized
-        optimizer (~torch.optim.Optimizer): optimizer to be used for training
-        optimizer_params(Dict): extra parameters for the optimizer
-        world_size (int): number of parallel workers
-        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
-            the module at beginning of the forward function. (default: ``True``)
-        process_group (optional): the c10d process group to be used for
-            distributed gradient reduction. If None, the default WORLD process group
-            will be used.
-        buffer_size (int, optional): number of elements to buffer before
-            performing reduce (default: 512k). Used to reduce multiple small
-            params to avoid communication overhead.
+    """
+    Wrap the model, and reduce the gradients to the right rank during the backward pass.
+
+    - the partition is given by the sharded optimizer
+    - wrap the base model with a model which knows where to reduce each gradient
+    - add an autograd function which calls the model grad dispatch on the way back
+
+    Args:
+        module (nn.Module):
+            model to be wrapped
+        sharded_optimizer (OSS, or list of OSS):
+            the sharded optimizer(s) which will decide the gradient partitioning
+
+    Keyword Args:
+        process_group (group):
+            torch.distributed group (default: group.WORLD)
+        broadcast_buffers (bool):
+            Whether to additionally broadcast model buffers in between ranks at the beginning of each forward pass.
+            Same setting as in Pytorch DDP, this is in addition to the broadcast and reduction of the model parameters.
+        sync_models_at_startup (bool):
+            Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
+            or if the training restarts from a saved state
     """

     def __init__(
         self,
         module: nn.Module,
-        optimizer: Type[torch.optim.Optimizer],
-        optimizer_params: Dict[str, Any],
-        world_size: int,
-        broadcast_buffers: bool,
+        sharded_optimizer: Union[OSS, List[OSS]],
         process_group: Any = None,
-        buffer_size: int = 2 ** 19,
+        broadcast_buffers: bool = True,
+        sync_models_at_startup: bool = True,
     ):
         super().__init__()

         self.module = module
-        self.world_size = world_size
+        self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
+        self.enable_broadcast_buffers = broadcast_buffers
+
+        # Handle a no_sync() context which prevents the gradient synchronization,
+        # accumulate in place
+        self.should_accumulate_grads = False
+
+        # Communication related attributes
         self.process_group = process_group if process_group is not None else dist.group.WORLD
+        self.world_size = dist.get_world_size(self.process_group)
+        self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
-        self.broadcast_buffers = broadcast_buffers
-        self.authoritative_rank = 0
+        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
+
+        # Expose some of the PytorchDDP attributes, some frameworks rely on them.
+        # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
+        # device_id related logic is not present, this is not handled
+        devices = {p.device for p in self.module.parameters()}
+        self.is_multi_device_module = len(devices) > 1
+        self.device = list(devices)[0]
+
+        distinct_device_types = {p.device.type for p in self.module.parameters()}
+        assert len(distinct_device_types) == 1, (
+            "ShardedDataParallel's input module must be on "
+            "the same type of devices, but input module parameters are located on {} different device types."
+        ).format(distinct_device_types)
+        self.device_type = list(distinct_device_types)[0]
+
+        # Scaffolding needed to reduce the grads during the BW pass.
+        # Several optimizers can be present, each working on separate parameter sets;
+        # we build an iterator which goes through all the parameters involved globally
+        self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
+        self._grad_to_be_reduced = [True for _ in self._param_iterator]
+        self._grad_accs: List[Callable] = []
+        self._setup_backward_hooks()
+
+        # Make sure that all ranks start with the same model
+        if sync_models_at_startup:
+            self._sync_params_and_buffers()
+
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
+        """
+        Module forward pass, handles any DDP-specific work in the background. Primes the
+        backward pass for gradient reduction to the proper ranks.
+        """
+        if self.enable_broadcast_buffers:
+            # NCCL communications are on a different stream, needs to be blocking
+            # for the subsequent FW to be correct
+            self.sync_buffers(blocking=True)

-        # Flag used to make sure we only reduce gradients one time in the execution engine
-        self.need_reduction = False
+        # Normal FW on the base model
+        return self.module(*inputs, **kwargs)

-        # We can also forcibly accumulate grads locally and only do the
-        # gradients-reduce at some later time
-        self.accumulate_grads = False
+    def reduce(self) -> None:
+        """ .. deprecated:: 0.0.4

-        # Build the sharded optimizer
-        self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)
+        This does not need to be called, the gradient reduction is done automatically during the BW pass
+        """
+        logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")

-        # Allocate reduce buffers
-        # - Never use a bigger buffer than the number of model params
-        buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
-        self._reduce_buffers: Dict[torch.device, List[torch.Tensor]] = {}
-
-        # - One buffer per rank per device
-        for device, per_device in self.sharded_optimizer.per_device_params.items():
-            buffer_dtype = per_device[0][0].dtype
-            self._reduce_buffers[device] = [
-                torch.zeros(buffer_size, dtype=buffer_dtype, device=device) for _ in range(len(per_device))
-            ]
-
-        # Sanity checks
-        assert len(self.sharded_optimizer.param_to_rank) == len(
-            list(self.module.parameters())
-        ), "number of params do not match"
-        for param in self.module.parameters():
-            assert param in self.sharded_optimizer.param_to_rank, f"{param} not in the optimizer"
-
-    def __getstate__(self) -> Dict:
-        attrs = copy.copy(self.__dict__)
-        return attrs
-
-    @property
-    def optimizer(self) -> torch.optim.Optimizer:
-        return self.sharded_optimizer
-
-    def train(self, mode: bool = True) -> "ShardedDataParallel":
-        pre_mode = self.module.training
-        self.module.train(mode)
-        if self.module.training:
-            assert not self.need_reduction or pre_mode, "incorrect state transition"
-        else:
-            assert not self.need_reduction, "try to enter eval with grads unreduced"
-        return self
-
-    @contextmanager
+    def _sync_params_and_buffers(self) -> None:
+        """
+        Sync the complete model states in between the ranks
+        """
+        with torch.no_grad():
+            work_handles = [
+                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
+                for t in self.module.state_dict().values()
+            ]
+
+            _ = list(map(lambda x: x.wait(), work_handles))
+
+    def sync_buffers(self, blocking: bool = False) -> None:
+        """
+        Sync all the param buffers in between ranks (including for instance batch norm statistics).
+        """
+        with torch.no_grad():
+            work_handles = [
+                dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
+                for buffer in self.module.buffers(recurse=True)
+            ]
+
+            if blocking:
+                _ = list(map(lambda x: x.wait(), work_handles))
+
+    @contextlib.contextmanager
     def no_sync(self) -> Generator:
         """A context manager to disable gradient synchronization."""
-        old_accumulate_grads = self.accumulate_grads
-        self.accumulate_grads = True
+        old_should_accumulate_grads = self.should_accumulate_grads
+        self.should_accumulate_grads = True
         yield
-        self.accumulate_grads = old_accumulate_grads
+        self.should_accumulate_grads = old_should_accumulate_grads

-    def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
-        if self.module.training:
-            if self.need_reduction:
-                raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
-            if not self.accumulate_grads:
-                self.need_reduction = True
-            if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
-                self._sync_buffers()
-        return self.module(*inputs, **kwargs)
+    def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
+        """ Look up where this parameter belongs to """
+        for optim in self.sharded_optimizers:
+            if param in optim.param_to_rank.keys():
+                return optim, optim.param_to_rank[param]

-    def reduce(self) -> None:
-        """
-        This function must be called explicitly after backward to reduce
-        gradients. There is no automatic hook like c10d.
-        """
-        assert self.module.training, "Cannot call reduce in eval"
+        assert False, "This parameter is not present in an optimizer, this should not happen"
+        return (None, -1)

-        if not self.need_reduction or self.accumulate_grads:
-            return
-
-        self.need_reduction = False
-
-        with torch.no_grad():
-            for device, per_device in self.sharded_optimizer.per_device_params.items():
-                self._reduce_grads_task(
-                    self._reduce_buffers[device],
-                    per_device,
-                    group=self.process_group,
-                    self_rank=self.rank,
-                    world_size=self.world_size,
-                )
-
-    @staticmethod
-    def _reduce_grads_task(
-        buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int, world_size: int
-    ) -> None:
-        """Helper to reduce a list of params. The params are sorted by size, smallest first, which allows for
-        an opportunistic bucketing.
-
-        NOTE: All param gradients are assumed to exist"""
-
-        buffer_size = buffers[0].numel()
-        bucket_requests = []
-        requests = []
-
-        for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
-            # All the params are sorted per rank and per increasing size
-            if len(params) == 0:
-                continue
-
-            for p in params:
-                if p.grad is None:
-                    p.grad = torch.zeros_like(p)
-
-            global_rank = OSS.get_global_rank(group, rank)
-
-            # Copy small gradients into per-GPU buffers and then async reduce
-            i_bucketed = 0  # the number of tensors packed in the buffer
-            offset = 0
-
-            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                buffer[offset:end].copy_(params[i_bucketed].grad.data.view(-1))  # type: ignore
-                offset = end
-                i_bucketed += 1
-
-            if i_bucketed > 0:
-                buffer.div_(world_size)
-                bucket_requests.append(
-                    (
-                        dist.reduce(tensor=buffer, dst=global_rank, group=group, async_op=True),  # type: ignore
-                        rank,
-                    )
-                )
-
-            # Directly reduce the other grads
-            for p in params[i_bucketed:]:
-                p.grad = cast(Tensor, p.grad)
-                if p.grad.requires_grad:
-                    raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
-                p.grad.div_(world_size)
-                requests.append(dist.reduce(tensor=p.grad, dst=global_rank, group=group, async_op=True))  # type: ignore
-
-        # Unroll the initial packed small gradients, as soon as possible
-        for future, rank in bucket_requests:
-            future.wait()
-
-            if rank == self_rank:
-                i_bucketed = 0  # the number of tensors packed in the buffer
-                offset = 0
-                params = per_rank_params[rank]
-                buffer = buffers[rank]
-
-                while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                    end = offset + params[i_bucketed].numel()
-                    params[i_bucketed].grad.data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
-                    offset = end
-                    i_bucketed += 1
-
-        # Make sure that we're done with this device before moving on and cleaning the unused params
-        _ = list(map(lambda x: x.wait(), requests))
-
-    def _sync_buffers(self) -> None:
-        """
-        Sync all the param buffers in between ranks.
-        TODO: Could be worth bucketing ?
-        """
-        _ = list(
-            map(
-                lambda x: x.wait(),
-                map(
-                    lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
-                    self.module.buffers(),
-                ),
-            )
-        )
+    def _get_reduce_fn(
+        self, index: int, param: torch.Tensor, should_bucket: bool, dst_rank: int, optimizer: OSS
+    ) -> Callable:
+        """
+        Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
+        or contribute to a bucket and reduce when the bucket is full.
+        Either way a delayed action is necessary and is passed as a callback.
+        """
+
+        def gatekeeper() -> None:
+            # Make sure that all the asynchronous calls have concluded before moving on. Consume the futures
+            # and execute the delayed actions (release gradients, unroll the buckets)
+            Variable._execution_engine.queue_callback(optimizer._consume_work_handles)
+
+            # Reset all the grad reduce and bucket state flags
+            self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
+
+        def reduce_direct(*_: Any) -> None:
+            # Skip gradient reduction, do not alter status flags
+            if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
+
+                # Make sure that this is not fired twice
+                self._grad_to_be_reduced[index] = False
+                param.grad /= self.world_size
+
+                # Future work includes clearing up the buffer if possible
+                def cleanup() -> None:
+                    if dst_rank != self.global_rank:
+                        param.grad = None
+
+                # Async reduce for this buffer, log the future
+                optimizer.work_handles.append(
+                    Workhandle(
+                        handle=dist.reduce(
+                            tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
+                        ),
+                        callback=cleanup,
+                    )
+                )
+
+                # If all the reduce operations have been called, add the gatekeeper
+                if len(optimizer.work_handles) == optimizer._max_work_handles:
+                    gatekeeper()
+
+        # Bucket, update status, and possibly unroll the results
+        def reduce_bucket(*_: Any) -> None:
+            # Skip gradient reduction, do not alter status flags
+            if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
+                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
+
+                # Make sure that this is not fired twice
+                self._grad_to_be_reduced[index] = False
+
+                # Copy to the flat buffer, update the buffer state
+                bucket = optimizer.buckets[param.device][dst_rank]
+
+                assert bucket.append(param, use_gradient=True), "Bucket overflow: max %s - current %s - adding %s" % (
+                    bucket.max_size,
+                    bucket.current_offset,
+                    param.grad.numel(),
+                )
+
+                if bucket.full():
+
+                    def unwrap() -> None:
+                        for flat in bucket.params:
+                            if dst_rank != self.global_rank:
+                                # this rank is not the owner, release the grad
+                                flat.param.grad = None
+                            else:
+                                # this rank is the owner, unroll the results
+                                assert flat.param.grad is not None
+
+                                flat.param.grad.data.copy_(
+                                    bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
+                                )
+
+                        bucket.reset()
+
+                    bucket.buffer /= self.world_size
+
+                    optimizer.work_handles.append(
+                        Workhandle(
+                            handle=dist.reduce(
+                                tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
+                            ),
+                            callback=unwrap,
+                        )
+                    )
+
+                    # If all the reduce operations have been called, add the gatekeeper
+                    if len(optimizer.work_handles) == optimizer._max_work_handles:
+                        gatekeeper()
+
+        return reduce_bucket if should_bucket else reduce_direct
+
+    def _setup_backward_hooks(self) -> None:
+        """
+        Attach a reduce function to each grad-requiring parameter.
+        This makes the gradient reduction automatic whenever there's a backward pass
+        """
+
+        # Go through the parameters, attach the hook
+        for sharded_optimizer in self.sharded_optimizers:
+            for param, should_bucket in sharded_optimizer.should_bucket_param.items():
+                if param.grad is not None and param.grad.requires_grad:
+                    raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
+
+                # Register the hook to the next function in line,
+                # so that the hook is fired when this grad has properly been computed
+                p_tmp = param.expand_as(param)
+                assert p_tmp.grad_fn is not None
+                grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                dst_rank = sharded_optimizer.param_to_rank[param]
+                index = len(self._grad_accs)
+
+                grad_acc.register_hook(self._get_reduce_fn(index, param, should_bucket, dst_rank, sharded_optimizer))
+                self._grad_accs.append(grad_acc)  # keep this function in scope
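The hook registration at the end of _setup_backward_hooks is the core mechanism: each parameter's AccumulateGrad node gets a hook, so the reduce can fire as soon as that gradient is ready. A stand-alone toy illustration of just that autograd trick, independent of fairscale (names below are made up):

import torch

param = torch.nn.Parameter(torch.ones(3))

# expand_as() gives a view whose grad_fn chains back to the parameter's AccumulateGrad node
p_tmp = param.expand_as(param)
grad_acc = p_tmp.grad_fn.next_functions[0][0]

def on_grad_ready(*_):
    # fired by the autograd engine once param.grad has been written for this backward pass
    print("grad ready:", param.grad)

grad_acc.register_hook(on_grad_ready)
(param * 2).sum().backward()  # prints a gradient tensor of 2s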
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Dict

 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
@@ -32,15 +32,15 @@ class ShardedGradScaler(TorchGradScaler):
     def __init__(self) -> None:
         super().__init__()

-    def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> Optional[float]:
+    def unscale_(self, optimizer: Optimizer) -> None:
         assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"

-        # Re-use the GradSCaler machinery, but make sure that the status is sync'ed in between the ranks
+        # Call the upstream unscale_ method which will only act on this rank's gradients
+        super().unscale_(optimizer)
+
+        # Synchronize the detected inf across the ranks
         optimizer_state = self._per_optimizer_states[id(optimizer)]
         handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]

         # Make sure that the calls are done before moving out
         _ = list(map(lambda x: x.wait(), handles))
-
-        # Call Torch's GradScaler in turn, states have been synchronized across ranks
-        return super().step(optimizer, *args, **kwargs)
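For reference, a hedged sketch of how the sharded scaler is meant to be combined with OSS and ShardedDataParallel, mirroring the AMP path of the benchmark in this diff; a CUDA device is assumed, and the ShardedGradScaler import path is inferred from the docs layout (optim/grad_scaler):

import tempfile

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

# Single-rank group, just to keep the sketch self-contained
dist.init_process_group(backend="nccl", init_method=f"file://{tempfile.mkstemp()[1]}", rank=0, world_size=1)

model = torch.nn.Linear(2, 2).cuda()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
model = ShardedDataParallel(model, optimizer)
scaler = ShardedGradScaler()

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.rand(8, 2, device="cuda")).abs().sum()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscale_ now syncs the found-inf flag across ranks before stepping
    scaler.update()

dist.destroy_process_group()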
@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

-from .utils import broadcast_object, recursive_copy_to_device
+from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device

 __all__ = ["OSS"]
@@ -73,7 +73,7 @@ class OSS(Optimizer):
         super().__init__(params, default)
         self.in_super_constructor = False

-        # Partition information. lazy evaluation, computed if requested
+        # Partition information. lazy evaluation, computed when requested
         self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
         self._param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
@@ -88,22 +88,26 @@ class OSS(Optimizer):
         # - Sync local and global param_groups keys
         for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
-            for k, v in local_group.items():
-                if k != "params":
-                    global_group[k] = v
+            for key, value in local_group.items():
+                if key != "params":
+                    global_group[key] = value

         # Optional consolidated optimizer state
         self._all_states: List[Dict[str, Any]] = []

         # Current default device is set by the parameters allocated to this rank
         self._device = self.partition_parameters()[self.rank][0]["params"][0].device
-        self._broadcast_buffers: Dict[torch.device, List[torch.Tensor]] = {}
+        self.buckets: Dict[torch.device, List[Bucket]] = {}
         for device, per_device in self.per_device_params.items():
             # Allocate one buffer per rank and per device to group the small parameters
-            self._broadcast_buffers[device] = [
-                torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device)
+            self.buckets[device] = [
+                Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device))
                 for _ in range(len(per_device))
             ]

+        self.should_bucket_param: Dict[torch.Tensor, bool] = {}
+        self.work_handles: List[Workhandle] = []
+        self._max_work_handles = -1
+        self._setup_bucket_strategy()
+
     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
@@ -150,9 +154,9 @@ class OSS(Optimizer):
                 self._per_device_params[device][self.param_to_rank[param]] += [param]

             # Sort param_lists by size
-            for k in self._per_device_params.keys():
-                for r in self._per_device_params[k]:
-                    r.sort(key=lambda x: x.numel())
+            for device in self._per_device_params.keys():
+                for rank_params in self._per_device_params[device]:
+                    rank_params.sort(key=lambda x: x.numel())

         return self._per_device_params
@@ -164,6 +168,9 @@ class OSS(Optimizer):
                 for param_group in param_groups:
                     for param in param_group["params"]:
                         self._param_rank[param] = rank
+
+            logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))
+
         return self._param_rank

     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
@@ -181,20 +188,16 @@ class OSS(Optimizer):
         self._sync_param_groups()

         # Run the optimizer step on this shard only:
+        self._free_other_grads()
+
         if closure is not None:
             loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
         else:
             loss = self.optim.step(**kwargs)

-        # Depending on the DDP engine used, gradients specific to other ranks may still be loaded
-        self._free_other_grads()
-
         # Sync all the updated shards in between the ranks
-        with torch.no_grad():
-            for (
-                device,
-                device_params,
-            ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-                self._broadcast_params(self._broadcast_buffers[device], device_params)
+        self._broadcast_params()

         # Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
         self._sync_param_groups(local_to_global=True)
@@ -489,61 +492,107 @@
             for t in p["params"]:
                 t.grad = None

-    def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
+    def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
-        buffer_size = buffers[0].numel()
-        bucket_requests = []
-        direct_requests = []
-
-        # Bucket and issue all the async calls
-        for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
-            global_src_rank = self.get_global_rank(self.group, src_rank)
-
-            # Copy small parameters into per-GPU buffers and then async broadcast
-            offset = 0
-            bucket_sent = False
-            bucket_params = []
-
-            # All the params are sorted per rank and per increasing size
-            for p in params:
-                # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
-                if not bucket_sent and offset + p.numel() < buffer_size:
-                    end = offset + p.numel()
-                    buffer[offset:end].copy_(p.data.view(-1))
-                    bucket_params.append((p, offset, end))
-                    offset = end
-                else:
-                    if offset > 0 and not bucket_sent:
-                        bucket_requests.append(
-                            (
-                                dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
-                                src_rank,
-                                bucket_params,
-                            )
-                        )
-                        bucket_sent = True
-
-                    direct_requests.append(
-                        dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
-                    )
-
-            # Catch a trailing bucket
-            if not bucket_sent:
-                bucket_requests.append(
-                    (
-                        dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
-                        src_rank,
-                        bucket_params,
-                    )
-                )
-
-        # Unroll the initial packed small parameters
-        for work_handle, src_rank, bucket_params in bucket_requests:
-            work_handle.wait()
-
-            if src_rank != self.rank:
-                for p, offset, end in bucket_params:
-                    p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
-
-        # Unroll all the async work items, just in case
-        _ = list(map(lambda x: x.wait(), direct_requests))
+
+        # The unroll callback is called when the broadcast is done.
+        # If this rank is a recipient and the call was bucketed, the results from the broadcast are unrolled
+        # onto the corresponding parameters.
+        def get_unroll_callback(src_rank: int, bucket: Bucket) -> Callable:
+            def unroll() -> None:
+                if src_rank != self.rank:
+                    for flat in bucket.params:
+                        flat.param.data.copy_(
+                            bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
+                        )
+                bucket.reset()
+
+            return unroll
+
+        with torch.no_grad():
+            for (
+                device,
+                device_params,
+            ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
+
+                buckets = self.buckets[device]
+
+                # Bucket and issue all the async calls
+                for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
+                    global_src_rank = self.get_global_rank(self.group, src_rank)
+
+                    for param in params:
+                        # Bucket broadcast
+                        if self.should_bucket_param[param]:
+                            assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
+                                bucket.max_size,
+                                bucket.current_offset,
+                                param.numel(),
+                            )
+
+                            if bucket.full():
+                                self.work_handles.append(
+                                    Workhandle(
+                                        handle=dist.broadcast(
+                                            tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
+                                        ),
+                                        callback=get_unroll_callback(src_rank, bucket),
+                                    )
+                                )
+                        # Direct
+                        else:
+                            self.work_handles.append(
+                                Workhandle(
+                                    handle=dist.broadcast(
+                                        tensor=param.data, src=global_src_rank, group=self.group, async_op=True
+                                    ),
+                                    callback=None,
+                                )
+                            )
+
+        self._consume_work_handles()
+
+    def _consume_work_handles(self) -> None:
+        """ Consume all the futures which are tied to this optimizer's buckets.
+        We start from the first/older ones, since they are the most likely to be ready and non-blocking
+        """
+        for work_handle in self.work_handles:
+            work_handle.handle.wait()
+            if work_handle.callback is not None:
+                work_handle.callback()
+
+        self.work_handles.clear()
+
+    def _setup_bucket_strategy(self) -> None:
+        """ Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
+        (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
+        over the wire.
+
+        Generating the partition once and for all allows us to save some time at runtime, and to know when all the
+        network requests have been issued.
+        """
+        for device, per_rank_params in self.per_device_params.items():
+            for dst_rank, params in enumerate(per_rank_params):
+                offset = 0
+                bucket_size = self.buckets[device][dst_rank].max_size
+
+                for param in params:
+                    if (offset + param.numel()) < bucket_size:
+                        # This parameter is small enough to fit in the remaining size of the bucket
+                        self.should_bucket_param[param] = True
+                        offset += param.numel()
+                    else:
+                        # The parameters are sorted by size, so all the following parameters
+                        # will be too big and can be skipped
+                        self.should_bucket_param[param] = False
+
+                # Register the max offset for this buffer
+                self.buckets[device][dst_rank].max_offset = offset
+
+        # Determine the max work handles in flight:
+        # - all the direct reduce/broadcast + 1 bucket
+        self._max_work_handles = sum(not value for value in self.should_bucket_param.values()) + 1
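The bucketing policy above can be illustrated with plain Python: parameters are visited smallest first, everything that still fits in the bucket is flagged for bucketing, and the first parameter that does not fit (plus all larger ones) goes over the wire directly. The sizes below are made up; the capacity matches the default broadcast_buffer_size:

bucket_capacity = 2 ** 19  # default broadcast_buffer_size, in elements
param_sizes = [1_000, 4_000, 260_000, 300_000, 1_000_000]  # hypothetical, sorted ascending

offset, plan = 0, {}
for size in param_sizes:
    if offset + size < bucket_capacity:
        plan[size] = "bucketed"
        offset += size
    else:
        plan[size] = "direct"

print(plan)  # {1000: 'bucketed', 4000: 'bucketed', 260000: 'bucketed', 300000: 'direct', 1000000: 'direct'}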
@@ -4,13 +4,70 @@
 # LICENSE file in the root directory of this source tree.

 import io
-from typing import Any, Dict
+from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch._six import container_abcs
 import torch.distributed as dist


+class Workhandle:
+    def __init__(self, handle: Any, callback: Optional[Callable]) -> None:
+        self.handle = handle
+        self.callback = callback
+
+
+class FlatParam:
+    def __init__(self, tensor: torch.Tensor, start: int, stop: int) -> None:
+        self.param = tensor
+        self.start = start
+        self.stop = stop
+
+
+class Bucket:
+    """
+    Helper class to simplify the handling of broadcast or reduce buckets
+    """
+
+    def __init__(self, buffer: torch.Tensor) -> None:
+        # The actual flat tensor
+        self.buffer = buffer
+        self.max_size = buffer.numel()
+
+        # Handles to the params and their position in this tensor, can be useful for a callback
+        self.params: List[FlatParam] = []
+
+        # Current status for this buffer
+        self.current_offset = 0
+        self.max_offset = 0
+
+    def reset(self) -> None:
+        """ empty the bucket """
+        self.current_offset = 0
+        self.params.clear()
+
+    def append(self, tensor: torch.Tensor, use_gradient: bool = False) -> bool:
+        """ add a tensor to the bucket """
+        end = self.current_offset + tensor.numel()
+        if end > self.max_size:
+            return False
+
+        if use_gradient:
+            assert tensor.grad is not None
+
+        data_source = tensor.grad.data if use_gradient else tensor.data  # type: ignore # mypy is drunk
+        self.buffer[self.current_offset : end].copy_(data_source.view(-1))
+        self.params.append(FlatParam(tensor=tensor, start=self.current_offset, stop=end))
+        self.current_offset = end
+
+        return True
+
+    def full(self) -> bool:
+        """ is the bucket full ? """
+        return self.current_offset == self.max_offset
+
+
 # Credits: classy_vision/generic/distributed_util.py
 def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
     """
@@ -324,7 +324,7 @@ class Tensor:
     def coalesce(self) -> Tensor: ...
     def conj(self) -> Tensor: ...
     def contiguous(self) -> Tensor: ...
-    def copy_(self, other: Tensor) -> None: ...
+    def copy_(self, other: Tensor, non_blocking: Optional[_bool]=False) -> None: ...
     def cos(self) -> Tensor: ...
     def cos_(self) -> Tensor: ...
     def cosh(self) -> Tensor: ...

@@ -12,3 +12,4 @@ class GradScaler(object):
     def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]:...
     def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ...
     def update(self, new_scale: Optional[float]=None): ...
+    def unscale_(self, optimizer: Optimizer) -> None: ...

@@ -28,8 +28,10 @@ class ReduceOp:
 def get_rank(group: Any = None) -> int: ...
 def get_world_size(group: Any = None) -> int: ...
+def get_backend(group: Optional[Any] = None) -> Any: ...
 def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
+def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group: Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
+def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group: Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
 def is_initialized() -> bool: ...
@@ -8,7 +8,9 @@ Testing OssDdp class.
 """

 import tempfile
+from typing import List

+import numpy as np
 import pytest
 import torch
 import torch.distributed as dist
@@ -16,18 +18,20 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential

 from fairscale.nn.data_parallel import ShardedDataParallel
+from fairscale.optim import OSS

 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")

+from contextlib import suppress

-def test_on_cpu():
-    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"))
+def test_step_on_cpu():
+    run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)


 @skip_if_no_cuda
 @skip_if_single_gpu
-def test_on_gpu():
+def test_step_on_gpu():
     run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
@@ -37,46 +41,78 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

-    # Any model works. Add one different buffer per rank
-    model = Sequential(Linear(2, 3)).to(device)
-    model.register_buffer("test_buffer", torch.ones((1)) * rank)
-
-    def weights_init(m):
-        if isinstance(m, Linear):
-            torch.nn.init.constant_(m.weight.data, 1.0)
-            torch.nn.init.constant_(m.bias.data, 1.0)
-
-    model.apply(weights_init)
-    model.to(device)
-
-    ddp = ShardedDataParallel(
-        module=model,
-        optimizer=torch.optim.SGD,
-        optimizer_params={"lr": 0.01, "momentum": 0.99},
-        world_size=world_size,
-        broadcast_buffers=True,
-    )
-    optimizer = ddp.optimizer
-    model = ddp.module
-
-    # Different input per rank, allows for checking that the gradients have been properly reduced
-    input_tensor = (torch.ones((64, 2)) * rank).to(device)
-    output = ddp(input_tensor).abs().sum()
-    output.backward()
-    ddp.reduce()
-
-    # Check that all the grads have been populated, for the shard
-    for pg in optimizer.optim.param_groups:
-        for param in pg["params"]:
-            if param.shape == torch.Size([3, 2]):
-                assert param.grad[0, 0].cpu() == torch.tensor([32.0])
-            if param.shape == torch.Size([3]):
-                assert param.grad[0].cpu() == torch.tensor([64.0])
-
-    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
-    for b in model.buffers():
-        assert b.cpu().item() == 0.0
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
+        # Any model works. Add one different buffer per rank
+        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+        model.register_buffer("test_buffer", torch.ones((1)) * rank)
+        model.to(device)
+
+        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+        ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
+
+        def check_same_model_params(same_params: bool):
+            # Check that all the params are the same on all ranks
+            # This should be true with and without broadcast_buffers, we don't have any real buffer here
+            receptacle: List[torch.Tensor] = []
+
+            if dist.get_backend() != "nccl":
+                for pg in optimizer.param_groups:
+                    for p in pg["params"]:
+                        # Check the params
+                        receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
+                        dist.gather(p, receptacle, dst=0)
+                        if rank == 0:
+                            for sync_p in receptacle[1:]:
+                                if same_params:
+                                    assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
+                                else:
+                                    assert not torch.all(
+                                        torch.eq(receptacle[0], sync_p)
+                                    ), "Gradients should not have been synced"
+
+                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
+                if broadcast_buffers:
+                    for b in ddp_model.buffers():
+                        receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
+                        dist.gather(b, receptacle, dst=0)
+                        if rank == 0:
+                            for sync_b in receptacle[1:]:
+                                if same_params:
+                                    assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
+                                else:
+                                    assert not torch.all(
+                                        torch.eq(receptacle[0], sync_b)
+                                    ), "Gradients should not have been synced"
+
+                        assert b.cpu().item() == 0.0
+
+        # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
+        check_same_model_params(same_params=True)
+
+        # Optim loop
+        def closure():
+            optimizer.zero_grad()
+            with ddp_model.no_sync() if grad_accumulation else suppress():
+                input_tensor = torch.rand((64, 2)).to(device)
+                loss = ddp_model(input_tensor).abs().sum()
+                loss.backward()
+            return loss
+
+        # The models should stay the same in between the ranks
+        for i in range(5):
+            _ = optimizer.step(closure=closure)
+
+        # when running on cpu/gloo the "nodes" are not really different
+        same_params = device == torch.device("cpu") or grad_accumulation
+        check_same_model_params(same_params=same_params)
+
+    check(broadcast_buffers=False)
+    check(broadcast_buffers=True)
+    check(broadcast_buffers=False, grad_accumulation=True)
+    check(broadcast_buffers=True, grad_accumulation=True)

     dist.destroy_process_group()
@@ -85,33 +121,116 @@ def run_test(backend, device, world_size=2):
     mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)


-def run_eval_mode(_unused):
-    """ Testing eval mode make sure this is no asserts. """
-    dist.init_process_group(
-        init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
-    )
-    model = Sequential(Linear(2, 3), Linear(3, 4))
-    optimizer_params = {"lr": 0.1, "momentum": 0.99}
-    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
-    optimizer = ddp.optimizer
-
-    ddp.eval()
-    for _ in range(5):
-        input_tensor = torch.rand((64, 2))
-        output = ddp(input_tensor)
-
-    ddp.train()
-    try:
-        for _ in range(5):
-            input_tensor = torch.rand((64, 2))
-            output = ddp(input_tensor)
-    except RuntimeError:
-        pass
-    else:
-        assert False, "Multiple forward passes on training mode should not pass"
+def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    if device == torch.device("cuda"):
+        torch.cuda.set_device(rank)
+
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    class _DoubleInput(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+
+        def forward(self, x, y):
+            x1 = self.mlp(x)
+            x2 = self.mlp(y)
+            return torch.cat((x1, x2), dim=1)
+
+    model = _DoubleInput().to(device)
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    # Optim loop
+    def closure():
+        optimizer.zero_grad()
+        input_tensor = torch.rand((64, 2)).to(device)
+        loss = ddp_model(input_tensor, input_tensor).abs().sum()
+        loss.backward()
+        return loss
+
+    # The models should stay the same in between the ranks
+    for i in range(5):
+        _ = optimizer.step(closure=closure)

     dist.destroy_process_group()


-def test_eval_mode():
-    mp.spawn(run_eval_mode, args=(), join=True)
+def test_inputs():
+    # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
+    world_size = 2
+    backend = "gloo"
+    temp_file_name = tempfile.mkstemp()[1]
+    device = "cpu"
+    mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
+
+
+def test_ddp_attributes():
+    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP
+    # - is multi_device_module
+    # - device_type
+    url = "file://" + tempfile.mkstemp()[1]
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
+
+    model = Sequential(Linear(2, 3), Linear(3, 3))
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    assert hasattr(ddp_model, "is_multi_device_module")
+    assert hasattr(ddp_model, "device_type")
+
+    dist.destroy_process_group()
+
+
+def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+    if device == torch.device("cuda"):
+        torch.cuda.set_device(rank)
+
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    class _DoubleInput(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+
+        def forward(self, x, y):
+            x1 = self.mlp(x)
+            x2 = self.mlp(y)
+            return torch.cat((x1, x2), dim=1)
+
+    model = _DoubleInput().to(device)
+
+    parameters = list(model.parameters())
+    optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])
+
+    # Optim loop
+    def closure():
+        optimizer_1.zero_grad()
+        optimizer_2.zero_grad()
+        input_tensor = torch.rand((64, 2)).to(device)
+        loss = ddp_model(input_tensor, input_tensor).abs().sum()
+        loss.backward()
+        return loss
+
+    # The models should stay the same in between the ranks
+    for i in range(5):
+        _ = optimizer_1.step(closure=closure)
+        _ = optimizer_2.step(closure=closure)
+
+    dist.destroy_process_group()
+
+
+def test_two_optimizers():
+    # Check that the ShardedDDP wrapper handles several sharded optimizers at once
+    world_size = 2
+    backend = "gloo"
+    temp_file_name = tempfile.mkstemp()[1]
+    device = "cpu"
+    mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)