Unverified Commit ade312c4 authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS-aware clip grads, bridge sharded states (#167)

Add a gradient-clipping util, equivalent to torch's but aware of the sharded state. Add a corresponding unit test.
parent 2fe93203
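For orientation, a condensed, runnable sketch of how the new util is meant to be called. It mirrors the tutorial and test changes below; the two-rank gloo setup, the tiny model, and the lr/max_norm values are illustrative assumptions only, not part of the commit.

import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.optim import OSS


def demo(rank: int, world_size: int, init_file: str, use_oss: bool) -> None:
    # Minimal process group so that OSS can all_reduce the grad norms (assumption: gloo on CPU)
    dist.init_process_group(backend="gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size)

    model = torch.nn.Linear(20, 5)
    loss_fn = torch.nn.L1Loss()

    if use_oss:
        # Sharded optimizer: each rank only keeps the optimizer state for its own parameter shard
        optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    model.zero_grad()
    loss = loss_fn(model(torch.rand(3, 20)), torch.rand(3, 5))
    loss.backward()

    # The stock torch util only sees this rank's partial view of the grads under OSS,
    # hence the OSS-aware equivalent introduced by this commit
    if use_oss:
        total_norm = optimizer.clip_grad_norm(max_norm=1.0, norm_type=2.0)
    else:
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0)

    optimizer.step()
    print(f"rank {rank} - loss {loss.item():.4f} - total grad norm {float(total_norm):.4f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(demo, args=(world_size, tempfile.mkstemp()[1], True), nprocs=world_size, join=True)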
@@ -136,7 +136,7 @@ def train(
for batch in dataloader:
batch__start = time.monotonic()
def closure():
def closure(data=batch):
model.zero_grad()
if args.debug and rank == 0 and next(model.parameters()).grad is not None:
logging.debug(
@@ -147,11 +147,11 @@ def train(
if not args.cpu and args.amp:
# Automatically computes the FW pass in half precision
with torch.cuda.amp.autocast():
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
outputs = model(data["inputs"])
loss = loss_fn(outputs, data["label"])
else:
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
outputs = model(data["inputs"])
loss = loss_fn(outputs, data["label"])
loss.backward()
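The closure change above (binding the batch as a default argument, def closure(data=batch)) presumably follows the usual Python pattern of freezing a loop variable at definition time rather than looking it up when the closure runs. A small, purely illustrative sketch of the difference (not from the commit):

def make_closures_late_binding(batches):
    # each lambda looks `b` up when it is *called*, so they all see the last value
    return [lambda: b for b in batches]

def make_closures_bound(batches):
    # a default argument is evaluated at definition time, freezing the current value
    return [lambda data=b: data for b in batches]

assert [c() for c in make_closures_late_binding([1, 2, 3])] == [3, 3, 3]
assert [c() for c in make_closures_bound([1, 2, 3])] == [1, 2, 3]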
@@ -257,9 +257,9 @@ if __name__ == "__main__":
args = parser.parse_args()
logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
logging.info(f"Benchmark arguments: {args}")
logging.info("Benchmark arguments: %s" % args)
backend = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
BACKEND = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
# Download dataset once for all processes
dataset, tentatives = None, 0
@@ -271,7 +271,7 @@ if __name__ == "__main__":
# Corrupted data, erase and restart
shutil.rmtree(TEMPDIR + "/MNIST")
logging.warning("Failed loading dataset: ", e)
logging.warning("Failed loading dataset: %s " % e)
tentatives += 1
if dataset is None:
@@ -285,7 +285,7 @@ if __name__ == "__main__":
logging.info("\n*** Benchmark vanilla optimizer")
mp.spawn(
train,
args=(args, backend, OptimType.vanilla, False,), # no regression check
args=(args, BACKEND, OptimType.vanilla, False,), # no regression check
nprocs=args.world_size,
join=True,
)
@@ -293,7 +293,7 @@ if __name__ == "__main__":
if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
logging.info("\n*** Benchmark OSS with DDP")
mp.spawn(
train, args=(args, backend, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
)
if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
@@ -302,7 +302,7 @@ if __name__ == "__main__":
train,
args=(
args,
backend,
BACKEND,
OptimType.oss_sharded_ddp,
False,
), # FIXME: @lefaudeux - SDP should give the same results
@@ -84,12 +84,7 @@ def train(rank, args, model, device, train_loader, num_epochs):
model.zero_grad()
outputs = model(data)
loss = loss_fn(outputs, target)
loss /= WORLD_SIZE
loss.backward()
# if dist.get_rank() == 0:
# print(f"Loss: {loss.item()}")
ddp.reduce() # Send the gradients to the appropriate shards
return loss
import time
from typing import Optional, Union, cast
import torch
import torch.distributed as dist
@@ -44,12 +45,14 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
dataloader = getData()
loss_fn = getLossFun()
base_optimizer_arguments = {"lr": 1e-4} # any optimizer specific arguments, LR, momentum, etc...
if ~use_oss:
optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
optimizer: Optional[Union[OSS, torch.optim.SGD]] = None
if not use_oss:
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4)
else:
base_optimizer = torch.optim.SGD
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
base_optimizer_arguments = {"lr": 1e-4} # any optimizer specific arguments, LR, momentum, etc...
optimizer = OSS(params=model.parameters(), optim=base_optimizer, default=base_optimizer_arguments)
training_start = time.monotonic()
# Any relevant training loop, nothing specific to OSS. For example:
@@ -57,13 +60,24 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
for e in range(epochs):
for (data, target) in dataloader:
data, target = data.to(rank), target.to(rank)
# Train
model.zero_grad()
outputs = model(data)
loss = loss_fn(outputs, target)
loss /= world_size
loss.backward()
# if you want to clip the gradients / get the current max:
max_norm = 1000.0
norm_type = 1
if not use_oss:
_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type=norm_type) # type: ignore
else:
optimizer = cast(OSS, optimizer)
_total_norm = optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
optimizer.step()
print(f"Loss: {loss.item()}")
training_end = time.monotonic()
@@ -5,9 +5,11 @@
from collections import OrderedDict
import copy
import itertools
from itertools import chain
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.distributed as dist
@@ -72,9 +74,7 @@ class OSS(Optimizer):
self.in_super_constructor = False
# Partition information. lazy evaluation, computed if requested
self._per_device_params: OrderedDict[
torch.device, List[List[Parameter]]
] = OrderedDict() # device, rank, params
self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
@@ -201,6 +201,70 @@ class OSS(Optimizer):
return loss
def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
"""
Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
Arguments:
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
.. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the grads
in the model, so calling it in the OSS context would lead to different scaling being applied per subset of model parameters
.. warning: This needs to be called on all ranks, since synchronization primitives will be used
.. warning: Model parallelism (process groups other than WORLD) is not yet supported
"""
if self.group != dist.group.WORLD:
raise NotImplementedError("Clip norm not yet supported for model parallelism (coming soon!)")
# Compute the norm of this shard's worth of gradients
max_norm = float(max_norm)
norm_type = float(norm_type)
# Filter out the grad-less params, concatenate params from all devices
local_params = itertools.chain(
*[
list(filter(lambda x: x.grad is not None, device_params[self.rank]))
for device_params in self.per_device_params.values()
]
)
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params) # type: ignore
dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
else:
local_norm = torch.norm(
input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type).to(self._device) for p in local_params]), # type: ignore
p=norm_type,
)
# the local norm can be combined with the remote ones by raising it to the power p:
# n_i = (sum_local |a|^p)^(1/p)
# -> n_total = (all_reduce_sum(n_i^p))^(1/p) = (sum_i sum_local |a|^p)^(1/p)
total_norm = local_norm ** norm_type
dist.all_reduce(total_norm, group=self.group)
total_norm = total_norm ** (1.0 / norm_type)
clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
if clip_coef < 1:
for device, device_params in self.per_device_params.items():
for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
p.grad.detach().mul_(clip_coef.to(device)) # type: ignore
return total_norm
# State dict interfaces
def local_state_dict(self) -> dict:
"""Gets this rank's state_dict.
@@ -345,6 +409,14 @@ class OSS(Optimizer):
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])
@staticmethod
def get_global_rank(group: Any, rank: int) -> int:
if group is dist.group.WORLD:
return rank
else:
global_rank = dist.distributed_c10d._get_global_rank(group, rank)
return global_rank
def _sync_param_groups(self, local_to_global: bool = False) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers).
If the global param groups have been altered, and we want to make sure that the
@@ -417,14 +489,6 @@ class OSS(Optimizer):
for t in p["params"]:
t.grad = None
@staticmethod
def get_global_rank(group: Any, rank: int) -> int:
if group is dist.group.WORLD:
return rank
else:
global_rank = dist.distributed_c10d._get_global_rank(group, rank)
return global_rank
def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
"""Helper function to broadcast all the parameters from a given device"""
buffer_size = buffers[0].numel()
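The norm accumulation used in clip_grad_norm above (each rank raises its local p-norm to the power p, the results are summed with an all_reduce, and the 1/p root is taken) relies on a simple identity that can be checked in isolation. In the sketch below the "shards" are plain tensor slices, an illustrative stand-in for the per-rank gradient sets:

import torch

# Check the accumulation identity: raising each shard's p-norm to the power p, summing
# across shards, and taking the 1/p root recovers the norm of the full gradient vector.
p = 2.0
full_grad = torch.randn(10)
shards = [full_grad[:6], full_grad[6:]]

local_norms = [torch.norm(shard, p=p) for shard in shards]
accumulated = sum(n ** p for n in local_norms) ** (1.0 / p)  # what all_reduce(SUM) + root gives
assert torch.allclose(accumulated, torch.norm(full_grad, p=p))

# The infinity norm is the exception and is combined with a MAX reduction instead,
# since the max over per-shard maxima equals the global maximum
assert torch.isclose(max(s.abs().max() for s in shards), full_grad.abs().max())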
@@ -600,6 +600,7 @@ class Tensor:
def mode(self, dim: _int=-1, keepdim: _bool=False) -> Tuple[Tensor, Tensor]: ...
@overload
def mode(self, dim: Union[str, None], keepdim: _bool=False) -> Tuple[Tensor, Tensor]: ...
def mul_(self, value: Union[_float, _int, Tensor]): ...
def multinomial(self, num_samples: _int, replacement: _bool=False, *, generator: Generator=None) -> Tensor: ...
def mv(self, vec: Tensor) -> Tensor: ...
def mvlgamma(self, p: _int) -> Tensor: ...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from . import Tensor
from typing import Tuple, List, Union
from typing import Tuple, List, Union, Optional, Any
def split(tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int=0) -> Tuple[Tensor,...]: ...
def einsum(equation: str, *operands: Tensor): ...
def norm(input: Tensor, p: Union[int, float, Any], dim: Optional[List[int]]=None, keep_dim: Optional[bool]=False, out: Optional[Tensor]=None, dtype:Optional[int]=None) -> Tensor : ...
@@ -8,6 +8,8 @@
# pylint: disable=missing-function-docstring
import copy
from math import inf
import tempfile
import unittest
@@ -16,6 +18,7 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim
@@ -447,3 +450,78 @@ def test_multiple_groups():
mp.spawn(
run_test_multiple_groups, args=(world_size, temp_file_name), nprocs=world_size, join=True,
)
def run_gradient_clipping(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name, backend="gloo")
device = torch.device(rank)
torch.manual_seed(rank)  # make sure that the different ranks get different data
# Run a dummy step so that the optimizer state dict exists
batch, input_width, hidden, target_width = 3, 20, 10, 5
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
NORMS = [1.0, 2.0, 1, 2, inf]
CLIP_NORM = 0.3
def check(norm):
model_oss = torch.nn.Sequential(
torch.nn.Linear(input_width, hidden),
torch.nn.Linear(hidden, hidden),
torch.nn.Linear(hidden, target_width),
).to(device)
model = copy.deepcopy(model_oss)
# For this test the gradients are all-reduced in the same way for the torch reference and for fairscale.
# Normally OSS would use ShardedDDP and only reduce to the owning rank, but that does not change the
# gradient norm computation from OSS and would add a dependency.
# To keep the comparison apples-to-apples, DDP is used in both cases.
model_oss = DDP(module=model_oss, device_ids=[rank],)
sharded_optimizer = optim.OSS(model_oss.parameters(), lr=0.1, momentum=0.99)
model = DDP(model, device_ids=[rank],)
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
model.zero_grad()
model_oss.zero_grad()
outputs = model(inputs)
outputs_oss = model_oss(inputs)
loss = loss_fn(outputs, target)
loss.backward()
loss_oss = loss_fn(outputs_oss, target)
loss_oss.backward()
# Check the equivalence with the non-sharded optim
oss_total_norm = sharded_optimizer.clip_grad_norm(CLIP_NORM, norm_type=norm)
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM, norm_type=norm)
assert torch.allclose(oss_total_norm, total_norm), "torch and fairscale should return the same grad norm"
# Check that the params have indeed been clipped
for params in sharded_optimizer.per_device_params.values():
for param in filter(lambda x: x.grad is not None, params[rank]):
assert torch.norm(param.grad, p=norm) < CLIP_NORM, f"param grad norm above clip : {param.grad}"
for norm in NORMS:
print(f"Checking norm {norm}")
check(norm)
dist.destroy_process_group()
@skip_if_no_cuda
def test_gradient_clipping():
world_size = 3
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count())
reference_rank = 0
mp.spawn(
run_gradient_clipping, args=(world_size, temp_file_name), nprocs=world_size, join=True,
)