Unverified Commit 341d8b2b authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS/SDP : bucketing (#122)

Use the same bucketing strategy for OSS and SDP:
sort all the parameters ahead of time, per rank and per increasing size, smallest tensors first. Pack the smallest elements into a fixed-size buffer and send it asynchronously, then send all the remaining tensors asynchronously as well, and come back to the bucket. Once the bucketed request is done, scatter its contents back into the original tensors if needed.
parent 6e7ad798
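The pack/unpack half of that strategy boils down to the sketch below. This is a minimal, standalone illustration, not the library code: the helper names (`pack_small_tensors`, `unpack_small_tensors`) and the sizes are made up for the example, and the actual `dist.reduce` / `dist.broadcast` async calls and future handling are elided.

```python
import torch


def pack_small_tensors(tensors, bucket):
    """Pack as many of the size-sorted tensors as fit into the flat bucket.

    Returns the number of tensors packed; the rest would be sent individually."""
    offset = 0
    n_packed = 0
    for t in tensors:
        numel = t.numel()
        if offset + numel > bucket.numel():
            break  # everything after this point is too big for the bucket
        bucket[offset : offset + numel].copy_(t.view(-1))
        offset += numel
        n_packed += 1
    return n_packed


def unpack_small_tensors(tensors, bucket, n_packed):
    """Copy the (reduced or broadcast) bucket contents back into the first n_packed tensors."""
    offset = 0
    for t in tensors[:n_packed]:
        numel = t.numel()
        t.copy_(bucket[offset : offset + numel].view_as(t))
        offset += numel


if __name__ == "__main__":
    # Tensors sorted per increasing size, the way per_device_params provides them
    tensors = sorted((torch.rand(n) for n in (4, 2, 8, 64)), key=lambda t: t.numel())
    bucket = torch.zeros(16)  # fixed-size flat buffer; one per rank and per device in practice

    n_packed = pack_small_tensors(tensors, bucket)
    # Here the bucket and the leftover tensors would be sent with
    # dist.reduce(..., async_op=True) / dist.broadcast(..., async_op=True),
    # the returned futures waited on, and only then the bucket unpacked:
    unpack_small_tensors(tensors, bucket, n_packed)
    print(f"packed {n_packed} of {len(tensors)} tensors into the bucket")
```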
@@ -165,6 +165,8 @@ def train(
     print("[Regression Test] VALID")
+
+    dist.destroy_process_group()  # type: ignore
 
 
 if __name__ == "__main__":
@@ -246,7 +248,10 @@ if __name__ == "__main__":
             backend,
             True,  # OSS
             True,  # SDP
-            False,  # no regression check
+            args.check_regression,
+            -1,  # Not checking SDP for speed regression for now, still slower than OSS
+            args.reference_memory,
+            args.reference_loss,
         ),
         nprocs=args.world_size,
         join=True,
...
@@ -11,7 +11,7 @@ Adopted from LegacyDistributedDataParallel module from fairseq.
 from contextlib import contextmanager
 import copy
-from typing import Any, Dict, Generator, List, Optional, Type, cast
+from typing import Any, Dict, Generator, List, Type, cast
 
 import torch
 from torch import Tensor, nn
@@ -39,7 +39,7 @@ class ShardedDataParallel(nn.Module):
             distributed gradient reduction. If None, the default WORLD process group
             will be used.
         buffer_size (int, optional): number of elements to buffer before
-            performing reduce (default: 256M). Used to reduce multiple small
+            performing reduce (default: 512k). Used to reduce multiple small
             params to avoid communication overhead.
     """
@@ -51,7 +51,7 @@ class ShardedDataParallel(nn.Module):
         world_size: int,
         broadcast_buffers: bool,
         process_group: Any = None,
-        buffer_size: int = 2 ** 28,
+        buffer_size: int = 2 ** 19,
     ):
         super().__init__()
@@ -62,10 +62,6 @@ class ShardedDataParallel(nn.Module):
         self.broadcast_buffers = broadcast_buffers
         self.authoritative_rank = 0
 
-        # Never use a bigger buffer than the number of model params
-        self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
-        self.buffer: Optional[Tensor] = None
-
         # Flag used to make sure we only reduce gradients one time in the execution engine
         self.need_reduction = False
@@ -76,6 +72,18 @@ class ShardedDataParallel(nn.Module):
         # Build the sharded optimizer
         self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)
 
+        # Allocate reduce buffers
+        # - Never use a bigger buffer than the number of model params
+        buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
+        self._reduce_buffers: Dict[torch.device, List[torch.Tensor]] = {}
+
+        # - One buffer per rank per device
+        for device, per_device in self.sharded_optimizer.per_device_params.items():
+            buffer_dtype = per_device[0][0].dtype
+            self._reduce_buffers[device] = [
+                torch.zeros(buffer_size, dtype=buffer_dtype, device=device) for _ in range(len(per_device))
+            ]
+
         # Sanity checks
         assert len(self.sharded_optimizer.param_to_rank) == len(
             list(self.module.parameters())
@@ -126,85 +134,92 @@ class ShardedDataParallel(nn.Module):
         """
         assert self.module.training, "Cannot call reduce in eval"
 
-        def reduce_grads(params: List[Parameter], params_rank: int) -> None:
-            """ Helper to reduce a list of params that should fit in the buffer.
-            NOTE: All param gradients are assumed to exist"""
-            assert self.buffer is not None
-
-            # Fill in the packed IO buffer
-            buffer: Tensor = cast(Tensor, self.buffer)
-            if len(params) > 1:
-                offset = 0
-                for p in params:
-                    sz = p.numel()
-                    buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
-                    offset += sz
-            else:
-                # we only have a single grad to reduce
-                buffer = params[0].grad.data  # type: ignore
-
-            # Reduce
-            buffer.div_(self.world_size)  # type: ignore
-            dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group)  # type: ignore
-
-            # Copy reduced grads back into their original place, or free corresponding memory
-            if params_rank == self.rank:
-                offset = 0
-                for p in params:
-                    sz = p.numel()
-                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
-                    offset += sz
-            else:
-                for p in params:
-                    p.grad = None
-
-        def reduction_fn() -> None:
-            # This function only needs to be called once
-            if not self.need_reduction or self.accumulate_grads:
-                return
-            self.need_reduction = False
-
-            if self.buffer is None:
-                self.buffer = next(self.module.parameters()).new(self.buffer_size)  # type: ignore
-
-            for params in self.sharded_optimizer.per_device_params:
-                # Reduce the gradients in buckets
-                offset = 0
-                buffered_params: List[Parameter] = []
-                param_rank: Optional[int] = None
-                for param in params:
-                    last_param_rank: Optional[int] = param_rank
-                    param_rank = self.sharded_optimizer.param_to_rank[param]
-                    if not param.requires_grad:
-                        continue
-
-                    if param.grad is None:
-                        param.grad = torch.zeros_like(param)
-
-                    if param.grad.requires_grad:
-                        raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
-
-                    sz = param.numel()
-                    if sz > self.buffer.numel():
-                        # reduce big params directly
-                        assert param_rank is not None
-                        reduce_grads([param], cast(int, param_rank))
-                    else:
-                        # smaller params are packed together from the same device
-                        # and same rank.
-                        if offset + sz > self.buffer.numel() or (
-                            last_param_rank is not None and last_param_rank != param_rank
-                        ):
-                            assert last_param_rank is not None
-                            reduce_grads(buffered_params, cast(int, last_param_rank))
-                            offset = 0
-                            buffered_params.clear()
-                        buffered_params.append(cast(Parameter, param))
-                        offset += sz
-                if len(buffered_params) > 0:
-                    assert param_rank is not None
-                    reduce_grads(buffered_params, cast(int, param_rank))
-
-        reduction_fn()
+        if not self.need_reduction or self.accumulate_grads:
+            return
+
+        self.need_reduction = False
+
+        with torch.no_grad():
+            for device, per_device in self.sharded_optimizer.per_device_params.items():
+                self._reduce_grads_task(
+                    self._reduce_buffers[device],
+                    per_device,
+                    group=self.process_group,
+                    self_rank=self.rank,
+                    world_size=self.world_size,
+                )
+
+    @staticmethod
+    def _reduce_grads_task(
+        buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int, world_size: int
+    ) -> None:
+        """Helper to reduce a list of params. The params are sorted by size, smallest first, which allows for
+        an opportunistic bucketing.
+
+        NOTE: All param gradients are assumed to exist"""
+
+        buffer_size = buffers[0].numel()
+        bucket_requests = []
+        requests = []
+
+        for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+            # All the params are sorted per rank and per increasing size
+            if len(params) == 0:
+                continue
+
+            for p in params:
+                if p.grad is None:
+                    p.grad = torch.zeros_like(p)
+
+            global_rank = OSS.get_global_rank(group, rank)
+
+            # Copy small gradients into per-GPU buffers and then async reduce
+            i_bucketed = 0  # the number of tensors packed in the buffer
+            offset = 0
+
+            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
+            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
+                end = offset + params[i_bucketed].numel()
+                buffer[offset:end].copy_(params[i_bucketed].grad.data.view(-1))  # type: ignore
+                offset = end
+                i_bucketed += 1
+
+            if i_bucketed > 0:
+                buffer.div_(world_size)  # type: ignore
+                bucket_requests.append(
+                    (
+                        dist.reduce(tensor=buffer, dst=global_rank, group=group, async_op=True),  # type: ignore
+                        rank,
+                    )
+                )
+
+            # Directly reduce the other grads
+            for p in params[i_bucketed:]:
+                p.grad = cast(Tensor, p.grad)
+                if p.grad.requires_grad:
+                    raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
+                p.grad.div_(world_size)  # type: ignore
+                requests.append(dist.reduce(tensor=p.grad, dst=global_rank, group=group, async_op=True))  # type: ignore
+
+        # Unroll the initial packed small gradients, as soon as possible
+        for future, rank in bucket_requests:
+            future.wait()
+
+            if rank == self_rank:
+                i_bucketed = 0  # the number of tensors packed in the buffer
+                offset = 0
+                params = per_rank_params[rank]
+                buffer = buffers[rank]
+
+                while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
+                    end = offset + params[i_bucketed].numel()
+                    params[i_bucketed].grad.data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
+                    offset = end
+                    i_bucketed += 1
+
+        # Make sure that we're done with this device before moving on and cleaning the unused params
+        _ = list(map(lambda x: x.wait(), requests))
 
     def _sync_buffers(self) -> None:
         """
...
@@ -49,6 +49,8 @@ class OSS(Optimizer):
             optimizer to shard (default: SGD)
         group (group):
             torch.distributed group (default: group.WORLD)
+        broadcast_buffer_size (int):
+            the size of the buffer used to batch the small parameter tensors (default 128k).
     """
 
     #: The optimizer used for a given shard
@@ -56,38 +58,53 @@ class OSS(Optimizer):
     in_super_constructor: bool
 
-    def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Optional[Any] = None, **default: Any):
+    def __init__(
+        self,
+        params: _params_t,
+        optim: Type[Optimizer] = SGD,
+        group: Optional[Any] = None,
+        broadcast_buffer_size: int = 2 ** 17,
+        **default: Any,
+    ):
         # Hold all the model params in the root .param_groups
         self.in_super_constructor = True
         super().__init__(params, default)
         self.in_super_constructor = False
 
         # Partition information. lazy evaluation, computed if requested
-        self._per_device_params: List[List[Parameter]] = []
+        self._per_device_params: OrderedDict[
+            torch.device, List[List[Parameter]]
+        ] = OrderedDict()  # device, rank, params
         self._param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
 
         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
         self.world_size = dist.get_world_size(self.group)
         self.rank = dist.get_rank(self.group)
         self.global_rank = self.get_global_rank(self.group, self.rank)
         self.optim = optim(self.partition_parameters()[self.rank], **default)
 
-        # Optional consolidated optimizer state
-        self._all_states: List[Dict[str, Any]] = []
-
-        # Current device is set by the parameters allocated to this rank
-        self._device = self.partition_parameters()[self.rank][0]["params"][0].device
-
-        # Sync local and global param_groups keys
+        # - Sync local and global param_groups keys
         for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
             for k, v in local_group.items():
                 if k != "params":
                     global_group[k] = v
 
+        # Optional consolidated optimizer state
+        self._all_states: List[Dict[str, Any]] = []
+
+        # Current default device is set by the parameters allocated to this rank
+        self._device = self.partition_parameters()[self.rank][0]["params"][0].device
+
+        self._broadcast_buffers: Dict[torch.device, List[torch.Tensor]] = {}
+        for device, per_device in self.per_device_params.items():
+            # Allocate one buffer per rank and per device to group the small parameters
+            self._broadcast_buffers[device] = [
+                torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device)
+                for _ in range(len(per_device))
+            ]
+
     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
         """Partitions parameters across distributed data parallel ranks.
@@ -116,22 +133,26 @@ class OSS(Optimizer):
         return self._partition_parameters
 
     @property
-    def per_device_params(self) -> List[List[Parameter]]:
-        # TODO (Min): The algorithm here can be improved. We are sorting params by device
-        #             and by rank. Then in reduction_fn below, we pack smaller ones into
-        #             a buffer for reduction.
-        #             We can pre-sort them here and simplify the reduction_fn logic below
-        #             since their size shouldn't change.
+    def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
+        """Sorted list of all the params, first per device then per rank.
+
+        Within a list params are sorted per number of elements to allow for an easy bucketing.
+        """
         if len(self._per_device_params) == 0:
+            # Go through all params, log them per device
+            # The ordering is important here, needs to be the same on all ranks
+            # So that ulterior broadcast calls are matching
             for param_group in self.param_groups:
-                param_lists: OrderedDict = OrderedDict()
                 for param in param_group["params"]:
                     device = param.device
-                    if param_lists.get(device) is None:
-                        param_lists[device] = []
-                    param_lists[device] += [param]
-            self._per_device_params = list(param_lists.values())
+                    if self._per_device_params.get(device) is None:
+                        self._per_device_params[device] = [[] for _ in range(self.world_size)]
+                    self._per_device_params[device][self.param_to_rank[param]] += [param]
+
+            # Sort param_lists by size
+            for k in self._per_device_params.keys():
+                for r in self._per_device_params[k]:
+                    r.sort(key=lambda x: x.numel())
 
         return self._per_device_params
@@ -145,13 +166,6 @@ class OSS(Optimizer):
             self._param_rank[param] = rank
         return self._param_rank
 
-    def get_global_rank(self, group: Any, rank: int) -> int:
-        if group is dist.group.WORLD:
-            return rank
-        else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore
-            return global_rank
-
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
@@ -174,25 +188,14 @@ class OSS(Optimizer):
         else:
             loss = self.optim.step(**kwargs)
 
-        # Sync all the states. Broadcast requests are issued async, we check completeness before moving on
-        requests = []
-        requires_grad = []
-        for rank, param_groups in enumerate(self.partition_parameters()):
-            for param_group in param_groups:
-                for param in param_group["params"]:
-                    # NOTE: Broadcast is in-place and not differentiable
-                    # Gloo will rightly assert on this operation for any tensor that requires grad.
-                    # We save and restore the grad requirement state to work around that, in our case
-                    # the grad is only useful on the source rank.
-                    global_rank = self.get_global_rank(self.group, rank)
-                    requires_grad.append((param, param.requires_grad))
-                    param.requires_grad = False
-                    requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))
-
-        for fut, req_grad in zip(requests, requires_grad):
-            fut.wait()
-            req_grad[0].requires_grad = req_grad[1]
+        # Sync all the updated shards in between the ranks
+        with torch.no_grad():
+            for (
+                device,
+                device_params,
+            ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
+                self._broadcast_params(self._broadcast_buffers[device], device_params, self.group, self.global_rank)
 
         return loss
 
     def local_state_dict(self) -> dict:
@@ -322,7 +325,10 @@ class OSS(Optimizer):
         super().add_param_group(param_group)
         if not self.in_super_constructor:
-            self._partition_parameters.clear()  # Force a re-partitioning
+            # Force a re-partitioning
+            self._partition_parameters.clear()
+            self._per_device_params.clear()
+            self._param_rank.clear()
+
             param_groups = self.partition_parameters()[self.rank]
             if len(param_groups) == len(self.optim.param_groups) + 1:
@@ -353,7 +359,6 @@ class OSS(Optimizer):
             else:
                 # Fetch the optim state from the other replicas
                 global_rank = self.get_global_rank(self.group, rank)
-                logging.debug("Receiving state from rank %s ", global_rank)
                 replica_state = broadcast_object(
                     empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
                 )
@@ -382,16 +387,93 @@ class OSS(Optimizer):
             else:
                 global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing
-                logging.debug("Discarding broadcast from rank %s", global_rank)
                 broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)
 
     def _free_other_grads(self) -> None:
         """Free all the gradients only useful for the other ranks
         """
-        for i, partition in enumerate(self.partition_parameters()):
-            if i == self.rank:
+        for rank, partition in enumerate(self.partition_parameters()):
+            if rank == self.rank:
                 continue
 
             for p in partition:
                 for t in p["params"]:
                     t.grad = None
+
+    @staticmethod
+    def get_global_rank(group: Any, rank: int) -> int:
+        if group is dist.group.WORLD:
+            return rank
+        else:
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore
+            return global_rank
+
+    @staticmethod
+    def _broadcast_params(
+        buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int
+    ) -> None:
+        """Helper function to broadcast all the parameters from a given device
+        """
+        buffer_size = buffers[0].numel()
+        restore_require_grad = []
+        bucket_requests = []
+        requests = []
+
+        # Bucket and issue all the async calls
+        for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+            # All the params are sorted per rank and per increasing size
+            if len(params) == 0:
+                continue
+
+            global_rank = OSS.get_global_rank(group, rank)
+
+            # Copy small parameters into per-GPU buffers
+            i_bucketed = 0  # the number of tensors packed in the buffer
+            offset = 0
+
+            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
+            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
+                end = offset + params[i_bucketed].numel()
+                if rank == self_rank:
+                    buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
+                offset = end
+                i_bucketed += 1
+
+            if i_bucketed > 0:
+                future = dist.broadcast(tensor=buffer, src=global_rank, group=group, async_op=True)
+                if rank != self_rank:
+                    # This request will need to be unrolled
+                    bucket_requests.append((future, rank))
+
+            # Directly broadcast the rest
+            for param in params[i_bucketed:]:
+                # NOTE: Broadcast is in-place and not differentiable
+                # Gloo will assert on this operation for any tensor that requires grad.
+                # We save and restore the grad requirement state to work around that, in our case
+                # the grad is only useful on the source rank.
+                if param.requires_grad:
+                    restore_require_grad.append(param)
+                    param.requires_grad = False
+
+                requests.append(dist.broadcast(tensor=param, src=global_rank, group=group, async_op=True))
+
+        # Unroll the initial packed small parameters
+        for gate, rank in bucket_requests:
+            gate.wait()
+
+            params = per_rank_params[rank]
+            buffer = buffers[rank]
+            i_bucketed = 0  # the number of tensors packed in the buffer
+            offset = 0
+
+            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
+                end = offset + params[i_bucketed].numel()
+                params[i_bucketed].data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
+                offset = end
+                i_bucketed += 1
+
+        # Unroll all the async work items, just in case
+        _ = list(map(lambda x: x.wait(), requests))
+
+        for p in restore_require_grad:
+            p.requires_grad = True
@@ -38,8 +38,15 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     torch.cuda.set_device(rank)
 
     # Any model works. Add one different buffer per rank
-    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
+    model = Sequential(Linear(2, 3)).to(device)
     model.register_buffer("test_buffer", torch.ones((1)) * rank)
+
+    def weights_init(m):
+        if isinstance(m, Linear):
+            torch.nn.init.constant_(m.weight.data, 1.0)
+            torch.nn.init.constant_(m.bias.data, 1.0)
+
+    model.apply(weights_init)
     model.to(device)
 
     ddp = ShardedDataParallel(
@@ -52,24 +59,26 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     optimizer = ddp.optimizer
     model = ddp.module
 
-    input_tensor = torch.rand((64, 2)).to(device)
-    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
+    # Different input per rank, allows for checking that the gradients have been properly reduced
+    input_tensor = (torch.ones((64, 2)) * rank).to(device)
+    output = ddp(input_tensor).abs().sum()
     output.backward()
     ddp.reduce()
 
     # Check that all the grads have been populated, for the shard
+    if device == torch.device("cuda"):
+        torch.cuda.synchronize()  # flush any remaining cuda op, just in case
+
     for pg in optimizer.optim.param_groups:
         for param in pg["params"]:
-            if param.requires_grad:
-                assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"
+            if param.shape == torch.Size([3, 2]):
+                assert param.grad[0, 0].cpu() == torch.tensor([32.0])
+            if param.shape == torch.Size([3]):
+                assert param.grad[0].cpu() == torch.tensor([64.0])
 
     # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
     for b in model.buffers():
         assert b.cpu().item() == 0.0
 
+    dist.destroy_process_group()
+
 
 def run_test(backend, device, world_size=2):
     temp_file_name = tempfile.mkstemp()[1]
@@ -101,6 +110,8 @@ def run_eval_mode(_unused):
     else:
         assert False, "Multiple forward passes on training mode should not pass"
 
+    dist.destroy_process_group()
+
 
 def test_eval_mode():
     mp.spawn(run_eval_mode, args=(), join=True)