Unverified Commit 49a198c9 authored by Benjamin Lefaudeux, committed by GitHub

[feat] Sharded DDP - small refactor and new features (#97)

- rename oss_ddp to ShardedDataParallel
- some refactoring
- ShardedDataParallel owns the sharded optimizer, exposed if need be
- some small perf bumps
parent 2ddce57f
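
For readers skimming the diff, the new API boils down to the pattern used in the benchmark and tests below. A minimal sketch, assuming a working torch.distributed setup; the toy two-layer model, gloo backend and hyper-parameters are illustrative only, not part of the commit:

# Sketch of the ShardedDataParallel API introduced by this commit.
# The model, backend and hyper-parameters below are placeholders.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel


def demo(rank: int, world_size: int, init_method: str) -> None:
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size, init_method=init_method)

    model = Sequential(Linear(2, 3), Linear(3, 4))
    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,                       # an optimizer class, not an instance
        optimizer_params={"lr": 0.1, "momentum": 0.99},  # forwarded to the wrapped optimizer
        world_size=world_size,
    )
    optimizer = ddp.optimizer  # the sharded OSS optimizer is owned by the wrapper, exposed here

    inputs = torch.rand(64, 2)
    loss = ddp(inputs).abs().sum() / inputs.numel()
    loss.backward()
    ddp.reduce()      # send each gradient to the rank that owns its optimizer shard
    optimizer.step()  # each rank updates its shard, then the updated params are synced

# One process per rank, e.g.:
# mp.spawn(demo, args=(2, "file:///tmp/sharded_ddp_demo"), nprocs=2, join=True)
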
......@@ -102,6 +102,12 @@ run_oss_benchmark: &run_oss_benchmark
command: |
python benchmarks/oss.py
run_oss_ddp_benchmark: &run_oss_ddp_benchmark
- run:
name: Run OSS DDP Benchmark
command: |
python benchmarks/oss.py --oss_ddp
# -------------------------------------------------------------------------------------
# Jobs to run
# -------------------------------------------------------------------------------------
......@@ -252,6 +258,8 @@ jobs:
- <<: *run_oss_benchmark
- <<: *run_oss_ddp_benchmark
workflows:
......
......@@ -17,7 +17,7 @@ repos:
- id: end-of-file-fixer
- repo: https://github.com/ambv/black
rev: stable
rev: 19.10b0
hooks:
- id: black
language_version: python3.6
......
......@@ -16,6 +16,7 @@ from torchvision.datasets import FakeData
from torchvision.models import resnet101
from torchvision.transforms import ToTensor
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
......@@ -28,21 +29,7 @@ def dist_init(rank, world_size):
dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
def train(
rank: int,
world_size: int,
num_epochs: int = 10,
batch_size: int = 32,
data_size: int = 200,
use_oss: bool = True,
check_regression: bool = True,
reference_speed: float = -1.0,
reference_memory: float = -1.0,
):
# DDP
dist_init(rank, world_size)
def get_problem(rank, data_size, batch_size):
# Standard RN101
model = resnet101(pretrained=False, progress=True).to(rank)
......@@ -57,14 +44,101 @@ def train(
dataset=FakeData(transform=ToTensor(), size=data_size), batch_size=batch_size, collate_fn=collate
)
loss_fn = nn.CrossEntropyLoss()
return model, dataloader, loss_fn
def train_oss_ddp(
rank: int, world_size: int, num_epochs: int = 10, batch_size: int = 32, data_size: int = 200,
):
# DDP
dist_init(rank, world_size)
# Setup
model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
ddp = ShardedDataParallel(
module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size
)
optimizer = ddp.optimizer
# Reset the memory use counter
torch.cuda.reset_peak_memory_stats(rank)
# Dummy training loop
torch.cuda.synchronize(rank)
training_start = time.monotonic()
model.train()
measurements = []
for epoch in range(num_epochs):
epoch_start = time.monotonic()
for batch in dataloader:
def closure():
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
loss /= world_size
loss.backward()
if dist.get_rank() == 0:
print(f"Loss: {loss.item()}")
ddp.reduce() # Send the gradients to the appropriate shards
return loss
optimizer.step(closure)
epoch_end = time.monotonic()
measurements.append(data_size / (epoch_end - epoch_start))
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
torch.cuda.synchronize(rank)
training_stop = time.monotonic()
img_per_sec = data_size / (training_stop - training_start) * num_epochs
max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
print(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall")
print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
# Compute the mean and standard deviation of the per-epoch img/sec measurements
mean = sum(measurements) / len(measurements)
diff = map(lambda x: pow(x - mean, 2.0), measurements)
std = math.sqrt(sum(diff) / (len(measurements) - 1))
print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")
def train(
rank: int,
world_size: int,
num_epochs: int = 10,
batch_size: int = 32,
data_size: int = 200,
use_oss: bool = True,
check_regression: bool = True,
reference_speed: float = -1.0,
reference_memory: float = -1.0,
):
# DDP
dist_init(rank, world_size)
# Setup
model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)
# Shard the optimizer
optimizer: torch.optim.Optimizer = OSS(
params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9
) if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
optimizer: torch.optim.Optimizer = (
OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
if use_oss
else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
)
# Reset the memory use counter
torch.cuda.reset_peak_memory_stats(rank)
# Dummy training loop
torch.cuda.synchronize(rank)
......@@ -95,9 +169,9 @@ def train(
# Check the checkpointing in the case of the OSS optimizer
# Memory usage could spill over from there
optimizer = cast(OSS, optimizer)
# optimizer.consolidate_state_dict()
optimizer.consolidate_state_dict()
if dist.get_rank() == 0:
# _ = optimizer.state_dict()
_ = optimizer.state_dict()
print("... State dict collected")
measurements.append(data_size / (epoch_end - epoch_start))
......@@ -137,9 +211,21 @@ if __name__ == "__main__":
parser.add_argument("--reference_speed", action="store", default=32.32, type=float)
parser.add_argument("--reference_memory", action="store", default=4475, type=float)
# beta - test oss_ddp
parser.add_argument("--oss_ddp", action="store_true", default=False)
args = parser.parse_args()
print(f"Benchmark arguments: {args}")
if args.oss_ddp:
print("\nBenchmark OSS DDP")
mp.spawn(
train_oss_ddp,
args=(args.world_size, args.epochs, args.batch_size, args.data_size),
nprocs=args.world_size,
join=True,
)
else:
print("\nBenchmark vanilla optimizer")
mp.spawn(
train,
......
......@@ -3,4 +3,4 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .oss_ddp import OssDdp
from .sharded_ddp import ShardedDataParallel
......@@ -9,26 +9,19 @@ A distributed data parallel class that works with OSS optimizer.
Adapted from the LegacyDistributedDataParallel module in fairseq.
"""
from collections import OrderedDict
from contextlib import contextmanager
import copy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, cast
from typing import Any, Dict, Generator, List, Optional, Type, cast
import torch
from torch import nn
from torch import Tensor, nn
import torch.distributed as dist
from torch.nn import Parameter
if TYPE_CHECKING:
from fairscale.optim import OSS
from torch import Tensor
from torch.nn import Parameter
else:
OSS = Any
Tensor = Any
Parameter = Any
from fairscale.optim import OSS
class OssDdp(nn.Module):
class ShardedDataParallel(nn.Module):
"""Implements distributed data parallel training with optimizer state sharding.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
......@@ -37,7 +30,8 @@ class OssDdp(nn.Module):
Args:
module (~torch.nn.Module): module to be parallelized
oss (fairscale.optim.OSS): shared state optimizer
optimizer (~torch.optim.Optimizer): optimizer to be used for training
optimizer_params (Dict): extra parameters for the optimizer
world_size (int): number of parallel workers
process_group (optional): the c10d process group to be used for
distributed gradient reduction. If None, the default WORLD process group
......@@ -48,7 +42,13 @@ class OssDdp(nn.Module):
"""
def __init__(
self, module: nn.Module, oss: OSS, world_size: int, process_group: Any = None, buffer_size: int = 2 ** 28
self,
module: nn.Module,
optimizer: Type[torch.optim.Optimizer],
optimizer_params: Dict[str, Any],
world_size: int,
process_group: Any = None,
buffer_size: int = 2 ** 28,
):
super().__init__()
......@@ -68,38 +68,25 @@ class OssDdp(nn.Module):
# gradients-reduce at some later time
self.accumulate_grads = False
# TODO (Min): The algorithm here can be improved. We are sorting params by device
# and by rank. Then in reduction_fn below, we pack smaller ones into
# a buffer for reduction.
# We can pre-sort them here and simplify the reduction_fn logic below
# since their size shouldn't change.
# make per-device lists of parameters
paramlists: OrderedDict = OrderedDict()
for param in self.module.parameters():
device = param.device
if paramlists.get(device) is None:
paramlists[device] = []
paramlists[device] += [param]
self.per_device_params = list(paramlists.values())
# query oss and build a param-to-rank table
self.param_rank = {}
for rank, param_groups in enumerate(oss.partition_parameters()):
for param_group in param_groups:
for param in param_group["params"]:
self.param_rank[param] = rank
# Build the sharded optimizer
self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)
# sanity checks
assert len(self.param_rank) == len(list(self.module.parameters())), "number of params do not match"
assert len(self.sharded_optimizer.param_to_rank) == len(
list(self.module.parameters())
), "number of params do not match"
for param in self.module.parameters():
assert param in self.param_rank, f"{param} not in the optimizer"
assert param in self.sharded_optimizer.param_to_rank, f"{param} not in the optimizer"
def __getstate__(self) -> Dict:
attrs = copy.copy(self.__dict__)
return attrs
def train(self, mode: bool = True) -> "OssDdp":
@property
def optimizer(self) -> torch.optim.Optimizer:
return self.sharded_optimizer
def train(self, mode: bool = True) -> "ShardedDataParallel":
pre_mode = self.module.training
self.module.train(mode)
if self.module.training:
......@@ -176,10 +163,9 @@ class OssDdp(nn.Module):
p.grad = buffer[offset : offset + sz].view_as(p).clone()
offset += sz
else:
# zero the grads
# wipe the grads
for p in params:
if p.grad is not None:
p.grad.data.zero_()
p.grad = None
def reduction_fn() -> None:
# This function only needs to be called once
......@@ -190,16 +176,17 @@ class OssDdp(nn.Module):
if self.buffer is None:
self.buffer = next(self.module.parameters()).new(self.buffer_size) # type: ignore
for params in self.per_device_params:
for params in self.sharded_optimizer.per_device_params:
# Reduce the gradients in buckets
offset = 0
buffered_params: List[Parameter] = []
param_rank: Optional[int] = None
for param in params:
last_param_rank: Optional[int] = param_rank
param_rank = self.param_rank[param]
param_rank = self.sharded_optimizer.param_to_rank[param]
if not param.requires_grad:
continue
if param.grad is None:
param.grad = torch.zeros_like(param)
if param.grad.requires_grad:
......@@ -219,7 +206,7 @@ class OssDdp(nn.Module):
reduce_params(buffered_params, cast(int, last_param_rank))
offset = 0
buffered_params.clear()
buffered_params.append(param)
buffered_params.append(cast(Parameter, param))
offset += sz
if len(buffered_params) > 0:
......
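
To make the new reduction path easier to follow outside the diff context: ShardedDataParallel.reduce() walks the parameters per device, looks up the owning rank through the sharded optimizer's param_to_rank table, and reduces each gradient onto that rank, dropping it everywhere else. A simplified sketch of that idea, with the bucketing of small gradients left out and the per-rank averaging shown as the conventional choice (the exact scaling lives in reduce_params, which is not fully visible in this hunk):

# Simplified sketch of the per-parameter gradient reduction behind
# ShardedDataParallel.reduce(). Bucketing of small grads into a flat buffer is
# omitted; an initialized process group is assumed.
from typing import Dict, List

import torch
import torch.distributed as dist


def reduce_grads_to_owners(
    params: List[torch.nn.Parameter],
    param_to_rank: Dict[torch.Tensor, int],
    world_size: int,
) -> None:
    for p in params:
        if not p.requires_grad:
            continue
        if p.grad is None:
            p.grad = torch.zeros_like(p)      # keep the collectives in step across ranks
        owner = param_to_rank[p]
        p.grad.data.div_(world_size)          # average across ranks (conventional DDP scaling)
        dist.reduce(p.grad.data, dst=owner)   # sum this gradient onto the owning rank
        if dist.get_rank() != owner:
            p.grad = None                     # non-owners wipe the grad, freeing memory
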
......@@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import copy
from itertools import chain
import logging
......@@ -10,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Ty
import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from .utils import broadcast_object, recursive_copy_to_device
......@@ -51,23 +53,29 @@ class OSS(Optimizer):
in_super_constructor: bool
def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Optional[Any] = None, **default: Any):
# Hold all the model params in the root .param_groups
self.in_super_constructor = True
super().__init__(params, defaults)
super().__init__(params, default)
self.in_super_constructor = False
# Partition information. lazy evaluation, computed if requested
self._per_device_params: List[List[Parameter]] = []
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
# Build the wrapped optimizer, responsible for a shard of the params
self.group = group
self.rank = dist.get_rank(group)
split_param_groups = self.partition_parameters()
self.optim = optim(split_param_groups[self.rank], **defaults)
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.optim = optim(self.partition_parameters()[self.rank], **default)
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
# Current device is set by the parameters allocated to this rank
self._device = split_param_groups[self.rank][0]["params"][0].device
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
# Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
......@@ -75,6 +83,7 @@ class OSS(Optimizer):
if k != "params":
global_group[k] = v
# Partition helpers
def partition_parameters(self) -> List[List[dict]]:
"""Partitions parameters across distributed ranks.
......@@ -83,21 +92,52 @@ class OSS(Optimizer):
corresponds to rank 0, etc. We need all the ranks for the broadcast
inside step().
"""
world_size = dist.get_world_size(self.group)
param_groups: List[List] = [list() for _ in range(world_size)]
sizes = [0] * world_size
if len(self._partition_parameters) == 0:
self._partition_parameters = [list() for _ in range(self.world_size)]
sizes = [0] * self.world_size
for param_group in self.param_groups:
param_lists: List[List] = [list() for _ in range(world_size)]
param_lists: List[List] = [list() for _ in range(self.world_size)]
for param in param_group["params"]:
# Add this param to rank with smallest size.
rank = sizes.index(min(sizes))
param_lists[rank].append(param)
sizes[rank] += param.numel()
for rank, params in enumerate(param_lists):
param_group_rank = copy.copy(param_group)
param_group_rank["params"] = params
param_groups[rank].append(param_group_rank)
return param_groups
self._partition_parameters[rank].append(param_group_rank)
return self._partition_parameters
@property
def per_device_params(self) -> List[List[Parameter]]:
# TODO (Min): The algorithm here can be improved. We are sorting params by device
# and by rank. Then in reduction_fn below, we pack smaller ones into
# a buffer for reduction.
# We can pre-sort them here and simplify the reduction_fn logic below
# since their size shouldn't change.
if len(self._per_device_params) == 0:
for param_group in self.param_groups:
param_lists: OrderedDict = OrderedDict()
for param in param_group["params"]:
device = param.device
if param_lists.get(device) is None:
param_lists[device] = []
param_lists[device] += [param]
self._per_device_params = list(param_lists.values())
return self._per_device_params
@property
def param_to_rank(self) -> Dict[torch.Tensor, int]:
if len(self._param_rank) == 0:
for rank, param_groups in enumerate(self.partition_parameters()):
for param_group in param_groups:
for param in param_group["params"]:
self._param_rank[param] = rank
return self._param_rank
# NOTE(msb) We add kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
......@@ -218,6 +258,8 @@ class OSS(Optimizer):
def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
if not self.in_super_constructor:
self._partition_parameters.clear() # Force a re-partitioning
param_groups = self.partition_parameters()[self.rank]
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])
......
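
The partitioning that param_to_rank is built on is a simple greedy scheme: each parameter goes to the rank that currently holds the fewest elements. A toy, self-contained illustration (the sizes below are made up; the real partition_parameters() works per param_group, keys off param.numel(), and caches the result):

# Toy illustration of the greedy partitioning used by OSS.partition_parameters().
# Sizes stand in for param.numel(); the real code partitions actual params.
from typing import List


def greedy_partition(param_sizes: List[int], world_size: int) -> List[List[int]]:
    partitions: List[List[int]] = [[] for _ in range(world_size)]
    totals = [0] * world_size
    for size in param_sizes:
        rank = totals.index(min(totals))  # rank with the smallest share so far
        partitions[rank].append(size)
        totals[rank] += size
    return partitions

# greedy_partition([100, 80, 60, 40, 20], world_size=2)
# -> [[100, 40, 20], [80, 60]]  i.e. 160 elements on rank 0, 140 on rank 1
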
......@@ -15,8 +15,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from fairscale.nn.data_parallel import OssDdp
from fairscale.optim import OSS
from fairscale.nn.data_parallel import ShardedDataParallel
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
......@@ -39,16 +38,30 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
torch.cuda.set_device(rank)
model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
ddp = OssDdp(model, optimizer, world_size)
ddp = ShardedDataParallel(
module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
)
optimizer = ddp.optimizer
input_tensor = torch.rand((64, 2)).to(device)
output = ddp(input_tensor).sum()
output = ddp(input_tensor).abs().sum() / input_tensor.numel()
output.backward()
ddp.reduce()
# Check that all the grads for this shard have been populated
if device == torch.device("cuda"):
torch.cuda.synchronize() # flush any remaining cuda op, just in case
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param.requires_grad:
assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"
# Check that the optimization process makes sense (ie. loss goes down for the same data)
optimizer.step()
# TODO (Min): I need to figure out a way to verify the grads are reduced correctly
# between the ranks. I haven't found the best way yet. Will need to come
# back here before this is used in real training.
new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
# assert new_eval.item() < output.item()
def run_test(backend, device, world_size=2):
......@@ -62,8 +75,9 @@ def run_eval_mode(_unused):
init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
)
model = Sequential(Linear(2, 3), Linear(3, 4))
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
ddp = OssDdp(model, optimizer, 1)
optimizer_params = {"lr": 0.1, "momentum": 0.99}
ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1)
optimizer = ddp.optimizer
ddp.eval()
for _ in range(5):
......