Unverified commit 49a198c9 authored by Benjamin Lefaudeux, committed by GitHub

[feat] Sharded DDP - small refactor and new features (#97)

- rename oss_ddp to ShardedDataParallel
- some refactoring
- ShardedDataParallel owns the sharded optimizer, exposed if need be
- some small perf bumps
parent 2ddce57f
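
For orientation, a minimal usage sketch of the refactored API, assembled from the benchmark and test changes in this diff. The toy model, tensors, and hyper-parameters are placeholders, and a process group is assumed to be initialized already (e.g. via dist.init_process_group); this is an illustration, not part of the commit.

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel

# Assumes dist.init_process_group(...) has already been called on every rank
model = Sequential(Linear(2, 3), Linear(3, 4))
ddp = ShardedDataParallel(
    module=model,
    optimizer=torch.optim.SGD,                       # optimizer class, wrapped and sharded by OSS internally
    optimizer_params={"lr": 0.1, "momentum": 0.99},
    world_size=dist.get_world_size(),
)
optimizer = ddp.optimizer                            # the sharded OSS optimizer, exposed if need be

inputs = torch.rand((64, 2))
loss = ddp(inputs).abs().sum() / inputs.numel()      # placeholder objective
loss.backward()
ddp.reduce()       # route each gradient to the rank owning the corresponding optimizer shard
optimizer.step()

ShardedDataParallel now builds the OSS optimizer itself and exposes it through the optimizer property, so callers pass an optimizer class plus its keyword arguments instead of constructing OSS by hand.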
@@ -102,6 +102,12 @@ run_oss_benchmark: &run_oss_benchmark
       command: |
         python benchmarks/oss.py

+run_oss_ddp_benchmark: &run_oss_ddp_benchmark
+  - run:
+      name: Run OSS DDP Benchmark
+      command: |
+        python benchmarks/oss.py --oss_ddp
+
 # -------------------------------------------------------------------------------------
 # Jobs to run
 # -------------------------------------------------------------------------------------
@@ -252,6 +258,8 @@ jobs:
       - <<: *run_oss_benchmark

+      - <<: *run_oss_ddp_benchmark
+
 workflows:
...
@@ -17,7 +17,7 @@ repos:
   - id: end-of-file-fixer
 - repo: https://github.com/ambv/black
-  rev: stable
+  rev: 19.10b0
   hooks:
   - id: black
     language_version: python3.6
...
@@ -16,6 +16,7 @@ from torchvision.datasets import FakeData
 from torchvision.models import resnet101
 from torchvision.transforms import ToTensor

+from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim.oss import OSS

 BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
@@ -28,21 +29,7 @@ def dist_init(rank, world_size):
     dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)


-def train(
-    rank: int,
-    world_size: int,
-    num_epochs: int = 10,
-    batch_size: int = 32,
-    data_size: int = 200,
-    use_oss: bool = True,
-    check_regression: bool = True,
-    reference_speed: float = -1.0,
-    reference_memory: float = -1.0,
-):
-    # DDP
-    dist_init(rank, world_size)
-
+def get_problem(rank, data_size, batch_size):
     # Standard RN101
     model = resnet101(pretrained=False, progress=True).to(rank)
@@ -57,14 +44,101 @@ def train(
         dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
     )
     loss_fn = nn.CrossEntropyLoss()
+    return model, dataloader, loss_fn
+
+
+def train_oss_ddp(
+    rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
+):
+    # DDP
+    dist_init(rank, world_size)
+
+    # Setup
+    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
+    ddp = ShardedDataParallel(
+        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
+    )
+    optimizer = ddp.optimizer
+
+    # Reset the memory use counter
+    torch.cuda.reset_peak_memory_stats(rank)
+
+    # Dummy training loop
+    torch.cuda.synchronize(rank)
+    training_start = time.monotonic()
+    model.train()
+
+    measurements = []
+
+    for epoch in range(num_epochs):
+        epoch_start = time.monotonic()
+
+        for batch in dataloader:
+
+            def closure():
+                model.zero_grad()
+                outputs = model(batch["inputs"])
+                loss = loss_fn(outputs, batch["label"])
+                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
+                loss /= world_size
+                loss.backward()
+
+                if dist.get_rank() == 0:
+                    print(f"Loss: {loss.item()}")
+
+                ddp.reduce()  # Send the gradients to the appropriate shards
+                return loss
+
+            optimizer.step(closure)
+
+        epoch_end = time.monotonic()
+
+        measurements.append(data_size / (epoch_end - epoch_start))
+        if dist.get_rank() == 0:
+            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
+
+    torch.cuda.synchronize(rank)
+    training_stop = time.monotonic()
+    img_per_sec = data_size / (training_stop - training_start) * num_epochs
+    max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
+
+    print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
+    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
+
+    # Compute the mean and average img per second
+    mean = sum(measurements) / len(measurements)
+    diff = map(lambda x: pow(x - mean, 2.0), measurements)
+    std = math.sqrt(sum(diff) / (len(measurements) - 1))
+    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
+
+
+def train(
+    rank: int,
+    world_size: int,
+    num_epochs: int = 10,
+    batch_size: int = 32,
+    data_size: int = 200,
+    use_oss: bool = True,
+    check_regression: bool = True,
+    reference_speed: float = -1.0,
+    reference_memory: float = -1.0,
+):
+    # DDP
+    dist_init(rank, world_size)
+
+    # Setup
+    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

     # Shard the optimizer
-    optimizer: torch.optim.Optimizer = OSS(
-        params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
-    ) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
+    optimizer: torch.optim.Optimizer = (
+        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
+        if use_oss
+        else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
+    )

     # Reset the memory use counter
     torch.cuda.reset_peak_memory_stats(rank)

     # Dummy training loop
     torch.cuda.synchronize(rank)
@@ -95,9 +169,9 @@ def train(
             # Check the checkpointing in the case of the OSS optimizer
             # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
-            # optimizer.consolidate_state_dict()
+            optimizer.consolidate_state_dict()
             if dist.get_rank() == 0:
-                # _ = optimizer.state_dict()
+                _ = optimizer.state_dict()
                 print("... State dict collected")

         measurements.append(data_size / (epoch_end - epoch_start))
@@ -137,30 +211,42 @@ if __name__ == "__main__":
     parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
     parser.add_argument("--reference_memory", action="store", default=4475, type=float)

+    # beta - test oss_ddp
+    parser.add_argument("--oss_ddp", action="store_true", default=False)
+
     args = parser.parse_args()
     print(f"Benchmark arguments: {args}")

-    print("\nBenchmark vanilla optimizer")
-    mp.spawn(
-        train,
-        args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
-        nprocs=args.world_size,
-        join=True,
-    )
-
-    print("\nBenchmark OSS")
-    mp.spawn(
-        train,
-        args=(
-            args.world_size,
-            args.epochs,
-            args.batch_size,
-            args.data_size,
-            True,
-            args.check_regression,
-            args.reference_speed,
-            args.reference_memory,
-        ),
-        nprocs=args.world_size,
-        join=True,
-    )
+    if args.oss_ddp:
+        print("\nBenchmark OSS DDP")
+        mp.spawn(
+            train_oss_ddp,
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size),
+            nprocs=args.world_size,
+            join=True,
+        )
+    else:
+        print("\nBenchmark vanilla optimizer")
+        mp.spawn(
+            train,
+            args=(args.world_size, args.epochs, args.batch_size, args.data_size, False, False),
+            nprocs=args.world_size,
+            join=True,
+        )
+
+        print("\nBenchmark OSS")
+        mp.spawn(
+            train,
+            args=(
+                args.world_size,
+                args.epochs,
+                args.batch_size,
+                args.data_size,
+                True,
+                args.check_regression,
+                args.reference_speed,
+                args.reference_memory,
+            ),
+            nprocs=args.world_size,
+            join=True,
+        )
@@ -3,4 +3,4 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from .oss_ddp import OssDdp
+from .sharded_ddp import ShardedDataParallel
@@ -9,26 +9,19 @@ A distributed data parallel class that works with OSS optimizer.
 Adopted from LegacyDistributedDataParallel module from fairseq.
 """

-from collections import OrderedDict
 from contextlib import contextmanager
 import copy
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, cast
+from typing import Any, Dict, Generator, List, Optional, Type, cast

 import torch
-from torch import nn
+from torch import Tensor, nn
 import torch.distributed as dist
+from torch.nn import Parameter

-if TYPE_CHECKING:
-    from fairscale.optim import OSS
-    from torch import Tensor
-    from torch.nn import Parameter
-else:
-    OSS = Any
-    Tensor = Any
-    Parameter = Any
+from fairscale.optim import OSS


-class OssDdp(nn.Module):
+class ShardedDataParallel(nn.Module):
     """Implements distributed data parallel training with optimizer state sharding.

     A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
@@ -37,7 +30,8 @@ class OssDdp(nn.Module):

     Args:
         module (~torch.nn.Module): module to be parallelized
-        oss (fairscale.optim.OSS): shared state optimizer
+        optimizer (~torch.optim.Optimizer): optimizer to be used for training
+        optimizer_params(Dict): extra parameters for the optimizer
         world_size (int): number of parallel workers
         process_group (optional): the c10d process group to be used for
             distributed gradient reduction. If None, the default WORLD process group
@@ -48,7 +42,13 @@ class OssDdp(nn.Module):
     """

     def __init__(
-        self, module: nn.Module, oss: OSS, world_size: int, process_group: Any = None, buffer_size: int = 2 ** 28
+        self,
+        module: nn.Module,
+        optimizer: Type[torch.optim.Optimizer],
+        optimizer_params: Dict[str, Any],
+        world_size: int,
+        process_group: Any = None,
+        buffer_size: int = 2 ** 28,
     ):
         super().__init__()
@@ -68,38 +68,25 @@ class OssDdp(nn.Module):
         # gradients-reduce at some later time
         self.accumulate_grads = False

-        # TODO (Min): The algorithm here can be improved. We are sorting params by device
-        #     and by rank. Then in reduction_fn below, we pack smaller ones into
-        #     a buffer for reduction.
-        #     We can pre-sort them here and simplify the reduction_fn logic below
-        #     since their size shouldn't change.
-
-        # make per-device lists of parameters
-        paramlists: OrderedDict = OrderedDict()
-        for param in self.module.parameters():
-            device = param.device
-            if paramlists.get(device) is None:
-                paramlists[device] = []
-            paramlists[device] += [param]
-        self.per_device_params = list(paramlists.values())
-
-        # query oss and build a param-to-rank table
-        self.param_rank = {}
-        for rank, param_groups in enumerate(oss.partition_parameters()):
-            for param_group in param_groups:
-                for param in param_group["params"]:
-                    self.param_rank[param] = rank
+        # Build the sharded optimizer
+        self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

         # sanity checks
-        assert len(self.param_rank) == len(list(self.module.parameters())), "number of params do not match"
+        assert len(self.sharded_optimizer.param_to_rank) == len(
+            list(self.module.parameters())
+        ), "number of params do not match"
         for param in self.module.parameters():
-            assert param in self.param_rank, f"{param} not in the optimizer"
+            assert param in self.sharded_optimizer.param_to_rank, f"{param} not in the optimizer"

     def __getstate__(self) -> Dict:
         attrs = copy.copy(self.__dict__)
         return attrs

-    def train(self, mode: bool = True) -> "OssDdp":
+    @property
+    def optimizer(self) -> torch.optim.Optimizer:
+        return self.sharded_optimizer
+
+    def train(self, mode: bool = True) -> "ShardedDataParallel":
         pre_mode = self.module.training
         self.module.train(mode)
         if self.module.training:
@@ -176,10 +163,9 @@ class OssDdp(nn.Module):
                     p.grad = buffer[offset : offset + sz].view_as(p).clone()
                     offset += sz
             else:
-                # zero the grads
+                # wipe the grads
                 for p in params:
-                    if p.grad is not None:
-                        p.grad.data.zero_()
+                    p.grad = None

         def reduction_fn() -> None:
             # This function only needs to be called once
@@ -190,16 +176,17 @@ class OssDdp(nn.Module):
             if self.buffer is None:
                 self.buffer = next(self.module.parameters()).new(self.buffer_size)  # type: ignore

-            for params in self.per_device_params:
+            for params in self.sharded_optimizer.per_device_params:
                 # Reduce the gradients in buckets
                 offset = 0
                 buffered_params: List[Parameter] = []
                 param_rank: Optional[int] = None
+
                 for param in params:
                     last_param_rank: Optional[int] = param_rank
-                    param_rank = self.param_rank[param]
+                    param_rank = self.sharded_optimizer.param_to_rank[param]
                     if not param.requires_grad:
                         continue
                     if param.grad is None:
                         param.grad = torch.zeros_like(param)
                     if param.grad.requires_grad:
@@ -219,7 +206,7 @@ class OssDdp(nn.Module):
                         reduce_params(buffered_params, cast(int, last_param_rank))
                         offset = 0
                         buffered_params.clear()
-                    buffered_params.append(param)
+                    buffered_params.append(cast(Parameter, param))
                     offset += sz

             if len(buffered_params) > 0:
...
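
For readers skimming the diff: reduce() above packs gradients into a flat buffer, bucket by bucket, and reduces each bucket to the rank that owns the matching optimizer shard (looked up through OSS.param_to_rank). The following standalone sketch illustrates that idea under simplified assumptions; reduce_grads_to_owner and its arguments are illustrative names, not the fairscale internals, and an initialized process group is assumed.

import torch
import torch.distributed as dist


def reduce_grads_to_owner(params, param_to_rank):
    # Group parameters by the rank that owns their optimizer shard
    by_rank = {}
    for p in params:
        if p.grad is not None:
            by_rank.setdefault(param_to_rank[p], []).append(p)

    for rank, owned in by_rank.items():
        # Pack the gradients into one flat buffer so small tensors share a single reduce call
        flat = torch.cat([p.grad.reshape(-1) for p in owned])
        dist.reduce(flat, dst=rank, op=dist.ReduceOp.SUM)

        if dist.get_rank() == rank:
            # The owning rank unpacks the summed gradients back into .grad
            offset = 0
            for p in owned:
                n = p.numel()
                p.grad.copy_(flat[offset : offset + n].view_as(p))
                offset += n
        else:
            # Non-owners can drop their local gradients, which is where the memory saving comes from
            for p in owned:
                p.grad = None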
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import OrderedDict
 import copy
 from itertools import chain
 import logging
@@ -10,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Ty

 import torch
 import torch.distributed as dist
+from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

 from .utils import broadcast_object, recursive_copy_to_device
@@ -51,23 +53,29 @@ class OSS(Optimizer):
     in_super_constructor: bool

-    def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
+    def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Optional[Any] = None, **default: Any):
         # Hold all the model params in the root .param_groups
         self.in_super_constructor = True
-        super().__init__(params, defaults)
+        super().__init__(params, default)
         self.in_super_constructor = False

+        # Partition information. lazy evaluation, computed if requested
+        self._per_device_params: List[List[Parameter]] = []
+        self._param_rank: Dict[torch.Tensor, int] = {}
+        self._partition_parameters: List[List[dict]] = []
+
         # Build the wrapped optimizer, responsible for a shard of the params
-        self.group = group
-        self.rank = dist.get_rank(group)
-        split_param_groups = self.partition_parameters()
-        self.optim = optim(split_param_groups[self.rank], **defaults)
+        self.group = group if group is not None else dist.group.WORLD
+        self.world_size = dist.get_world_size(self.group)
+        self.rank = dist.get_rank(self.group)
+        self.optim = optim(self.partition_parameters()[self.rank], **default)

         # Optional consolidated optimizer state
         self._all_states: List[Dict[str, Any]] = []

         # Current device is set by the parameters allocated to this rank
-        self._device = split_param_groups[self.rank][0]["params"][0].device
+        self._device = self.partition_parameters()[self.rank][0]["params"][0].device

         # Sync local and global param_groups keys
         for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
@@ -75,6 +83,7 @@ class OSS(Optimizer):
                 if k != "params":
                     global_group[k] = v

+    # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
         """Partitions parameters across distributed ranks.
@@ -83,21 +92,52 @@ class OSS(Optimizer):
         corresponds to rank 0, etc. We need all the ranks for the broadcast
         inside step().
         """
-        world_size = dist.get_world_size(self.group)
-        param_groups: List[List] = [list() for _ in range(world_size)]
-        sizes = [0] * world_size
-        for param_group in self.param_groups:
-            param_lists: List[List] = [list() for _ in range(world_size)]
-            for param in param_group["params"]:
-                # Add this param to rank with smallest size.
-                rank = sizes.index(min(sizes))
-                param_lists[rank].append(param)
-                sizes[rank] += param.numel()
-            for rank, params in enumerate(param_lists):
-                param_group_rank = copy.copy(param_group)
-                param_group_rank["params"] = params
-                param_groups[rank].append(param_group_rank)
-        return param_groups
+        if len(self._partition_parameters) == 0:
+            self._partition_parameters = [list() for _ in range(self.world_size)]
+            sizes = [0] * self.world_size
+            for param_group in self.param_groups:
+                param_lists: List[List] = [list() for _ in range(self.world_size)]
+                for param in param_group["params"]:
+                    # Add this param to rank with smallest size.
+                    rank = sizes.index(min(sizes))
+                    param_lists[rank].append(param)
+                    sizes[rank] += param.numel()
+
+                for rank, params in enumerate(param_lists):
+                    param_group_rank = copy.copy(param_group)
+                    param_group_rank["params"] = params
+                    self._partition_parameters[rank].append(param_group_rank)
+
+        return self._partition_parameters
+
+    @property
+    def per_device_params(self) -> List[List[Parameter]]:
+        # TODO (Min): The algorithm here can be improved. We are sorting params by device
+        #     and by rank. Then in reduction_fn below, we pack smaller ones into
+        #     a buffer for reduction.
+        #     We can pre-sort them here and simplify the reduction_fn logic below
+        #     since their size shouldn't change.
+        if len(self._per_device_params) == 0:
+            for param_group in self.param_groups:
+                param_lists: OrderedDict = OrderedDict()
+                for param in param_group["params"]:
+                    device = param.device
+                    if param_lists.get(device) is None:
+                        param_lists[device] = []
+                    param_lists[device] += [param]
+                self._per_device_params = list(param_lists.values())
+
+        return self._per_device_params
+
+    @property
+    def param_to_rank(self) -> Dict[torch.Tensor, int]:
+        if len(self._param_rank) == 0:
+            for rank, param_groups in enumerate(self.partition_parameters()):
+                for param_group in param_groups:
+                    for param in param_group["params"]:
+                        self._param_rank[param] = rank
+
+        return self._param_rank

     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
@@ -218,6 +258,8 @@ class OSS(Optimizer):
     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
         if not self.in_super_constructor:
+            self._partition_parameters.clear()  # Force a re-partitioning
+
             param_groups = self.partition_parameters()[self.rank]
             if len(param_groups) == len(self.optim.param_groups) + 1:
                 self.optim.add_param_group(param_groups[-1])
...
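
The lazily computed partition_parameters() above uses a greedy balancing rule: each parameter goes to whichever rank currently holds the fewest elements, so the optimizer state ends up roughly even across ranks. A small self-contained sketch of that rule, with illustrative names rather than the fairscale code itself:

import torch


def greedy_partition(params, world_size):
    shards = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    for p in params:
        rank = sizes.index(min(sizes))  # the smallest shard so far
        shards[rank].append(p)
        sizes[rank] += p.numel()
    return shards


# Example: three tensors spread over two ranks end up with 10 elements each
tensors = [torch.zeros(10), torch.zeros(4), torch.zeros(6)]
print([sum(t.numel() for t in shard) for shard in greedy_partition(tensors, 2)])  # [10, 10]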
@@ -15,8 +15,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential

-from fairscale.nn.data_parallel import OssDdp
-from fairscale.optim import OSS
+from fairscale.nn.data_parallel import ShardedDataParallel

 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
@@ -39,16 +38,30 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
         torch.cuda.set_device(rank)

     model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
-    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
-    ddp = OssDdp(model, optimizer, world_size)
+    ddp = ShardedDataParallel(
+        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
+    )
+    optimizer = ddp.optimizer

     input_tensor = torch.rand((64, 2)).to(device)
-    output = ddp(input_tensor).sum()
+    output = ddp(input_tensor).abs().sum() / input_tensor.numel()
     output.backward()
     ddp.reduce()

+    # Check that all the grads have been populated, for the shard
+    if device == torch.device("cuda"):
+        torch.cuda.synchronize()  # flush any remaining cuda op, just in case
+
+    for pg in optimizer.optim.param_groups:
+        for param in pg["params"]:
+            if param.requires_grad:
+                assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"
+
+    # Check that the optimization process makes sense (ie. loss goes down for the same data)
     optimizer.step()
-    # TODO (Min): I need to figure out a way to verify the grads are reduced correctly
-    #     between the ranks. I haven't found the best way yet. Will need to come
-    #     back here before this is used in real training.
+    new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
+    # assert new_eval.item() < output.item()


 def run_test(backend, device, world_size=2):
@@ -62,8 +75,9 @@ def run_eval_mode(_unused):
         init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
     )
     model = Sequential(Linear(2, 3), Linear(3, 4))
-    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
-    ddp = OssDdp(model, optimizer, 1)
+    optimizer_params = {"lr": 0.1, "momentum": 0.99}
+    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1)
+    optimizer = ddp.optimizer
     ddp.eval()
     for _ in range(5):
...