Unverified Commit ade312c4, authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS-aware clip grads, bridge sharded states (#167)

Add a gradient clipping util, equivalent to torch's but aware of the sharded state, plus a corresponding unit test.
parent 2fe93203
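A minimal usage sketch of the new utility, condensed from the tutorial and test changes below. A single-process gloo group is spun up only to make the snippet self-contained; the actual benchmark and tutorial run the same calls per rank under mp.spawn, and the model, data and hyper-parameters here are illustrative only:

# Sketch only: single-process stand-in for the per-rank function used in the diffs below.
import tempfile

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# World of size 1, just so that OSS has a process group to work with
dist.init_process_group(backend="gloo", init_method=f"file://{tempfile.mkstemp()[1]}", rank=0, world_size=1)

model = torch.nn.Linear(20, 5)
data, target = torch.rand(3, 20), torch.rand(3, 5)
loss_fn = torch.nn.L1Loss()

# OSS shards the SGD state across the ranks (a single shard here)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4)

model.zero_grad()
loss = loss_fn(model(data), target)
loss.backward()

# OSS-aware counterpart of torch.nn.utils.clip_grad_norm_: must be called on every
# rank, since the total gradient norm is all-reduced across the parameter shards
total_norm = optimizer.clip_grad_norm(max_norm=1.0, norm_type=2.0)
optimizer.step()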
@@ -136,7 +136,7 @@ def train(
         for batch in dataloader:
             batch__start = time.monotonic()

-            def closure():
+            def closure(data=batch):
                 model.zero_grad()
                 if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                     logging.debug(
@@ -147,11 +147,11 @@ def train(
                 if not args.cpu and args.amp:
                     # Automatically computes the FW pass in half precision
                     with torch.cuda.amp.autocast():
-                        outputs = model(batch["inputs"])
-                        loss = loss_fn(outputs, batch["label"])
+                        outputs = model(data["inputs"])
+                        loss = loss_fn(outputs, data["label"])
                 else:
-                    outputs = model(batch["inputs"])
-                    loss = loss_fn(outputs, batch["label"])
+                    outputs = model(data["inputs"])
+                    loss = loss_fn(outputs, data["label"])

                 loss.backward()
@@ -257,9 +257,9 @@ if __name__ == "__main__":
     args = parser.parse_args()

     logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
-    logging.info(f"Benchmark arguments: {args}")
+    logging.info("Benchmark arguments: %s" % args)

-    backend = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"
+    BACKEND = "nccl" if (not args.gloo or not torch.cuda.is_available()) and not args.cpu else "gloo"

     # Download dataset once for all processes
     dataset, tentatives = None, 0
@@ -271,7 +271,7 @@ if __name__ == "__main__":
                 # Corrupted data, erase and restart
                 shutil.rmtree(TEMPDIR + "/MNIST")
-            logging.warning("Failed loading dataset: ", e)
+            logging.warning("Failed loading dataset: %s " % e)
             tentatives += 1

     if dataset is None:
@@ -285,7 +285,7 @@ if __name__ == "__main__":
         logging.info("\n*** Benchmark vanilla optimizer")
         mp.spawn(
             train,
-            args=(args, backend, OptimType.vanilla, False,),  # no regression check
+            args=(args, BACKEND, OptimType.vanilla, False,),  # no regression check
             nprocs=args.world_size,
             join=True,
         )
@@ -293,7 +293,7 @@ if __name__ == "__main__":
     if args.optim_type == OptimType.oss_ddp or args.optim_type == OptimType.everyone:
         logging.info("\n*** Benchmark OSS with DDP")
         mp.spawn(
-            train, args=(args, backend, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
+            train, args=(args, BACKEND, OptimType.oss_ddp, args.check_regression), nprocs=args.world_size, join=True,
         )

     if args.optim_type == OptimType.oss_sharded_ddp or args.optim_type == OptimType.everyone:
@@ -302,7 +302,7 @@ if __name__ == "__main__":
             train,
             args=(
                 args,
-                backend,
+                BACKEND,
                 OptimType.oss_sharded_ddp,
                 False,
             ),  # FIXME: @lefaudeux - SDP should give the same results
...
@@ -84,12 +84,7 @@ def train(rank, args, model, device, train_loader, num_epochs):
             model.zero_grad()
             outputs = model(data)
             loss = loss_fn(outputs, target)
-            loss /= WORLD_SIZE
             loss.backward()

-            # if dist.get_rank() == 0:
-            #     print(f"Loss: {loss.item()}")
             ddp.reduce()  # Send the gradients to the appropriate shards
             return loss
...
 import time
+from typing import Optional, Union, cast

 import torch
 import torch.distributed as dist
@@ -44,12 +45,14 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     dataloader = getData()
     loss_fn = getLossFun()

-    base_optimizer_arguments = {"lr": 1e-4}  # any optimizer specific arguments, LR, momentum, etc...
-    if ~use_oss:
-        optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
+    optimizer: Optional[Union[OSS, torch.optim.SGD]] = None
+    if not use_oss:
+        optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4)
     else:
         base_optimizer = torch.optim.SGD
-        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
+        base_optimizer_arguments = {"lr": 1e-4}  # any optimizer specific arguments, LR, momentum, etc...
+        optimizer = OSS(params=model.parameters(), optim=base_optimizer, default=base_optimizer_arguments)

     training_start = time.monotonic()
     # Any relevant training loop, nothing specific to OSS. For example:
@@ -57,13 +60,24 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     for e in range(epochs):
         for (data, target) in dataloader:
             data, target = data.to(rank), target.to(rank)

             # Train
             model.zero_grad()
             outputs = model(data)
             loss = loss_fn(outputs, target)
-            loss /= world_size
             loss.backward()

+            # if you want to clip the gradients / get the current max:
+            max_norm = 1000.0
+            norm_type = 1
+            if not use_oss:
+                _total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type=norm_type)  # type: ignore
+            else:
+                optimizer = cast(OSS, optimizer)
+                _total_norm = optimizer.clip_grad_norm(max_norm, norm_type=norm_type)
+
             optimizer.step()
             print(f"Loss: {loss.item()}")

     training_end = time.monotonic()
...
@@ -5,9 +5,11 @@

 from collections import OrderedDict
 import copy
+import itertools
 from itertools import chain
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
+from math import inf
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.distributed as dist
@@ -72,9 +74,7 @@ class OSS(Optimizer):
         self.in_super_constructor = False

         # Partition information. lazy evaluation, computed if requested
-        self._per_device_params: OrderedDict[
-            torch.device, List[List[Parameter]]
-        ] = OrderedDict()  # device, rank, params
+        self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
         self._param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
@@ -201,6 +201,70 @@ class OSS(Optimizer):

         return loss

+    def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
+        """
+        Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
+        concatenated into a single vector. Gradients are modified in-place.
+
+        Arguments:
+            max_norm (float or int): max norm of the gradients
+            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+
+        Returns:
+            Total norm of the parameters (viewed as a single vector).
+
+        .. note: This is analogous to `torch.nn.utils.clip_grad_norm_` but handles the partitioning and multiple devices per rank
+            under the hood. The default torch util is not applicable here, because each rank only has a partial view of all the
+            grads in the model, so calling it in the OSS context would lead to different scaling being applied per subset of
+            model parameters.
+
+        .. warning: This needs to be called on all ranks, since synchronization primitives will be used.
+
+        .. warning: Model parallelism -groups other than world- is not yet supported.
+        """
+
+        if self.group != dist.group.WORLD:
+            raise NotImplementedError("Clip norm not yet supported for model parallelism (coming soon!)")
+
+        # Compute the max norm for this shard's worth of gradients
+        max_norm = float(max_norm)
+        norm_type = float(norm_type)
+
+        # Filter out the grad-less params, concatenate params from all devices
+        local_params = itertools.chain(
+            *[
+                list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                for device_params in self.per_device_params.values()
+            ]
+        )
+
+        # Compute the norm on this grad set,
+        # then sync all the norms from all ranks
+        if norm_type == inf:
+            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
+            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
+        else:
+            local_norm = torch.norm(
+                input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type).to(self._device) for p in local_params]),  # type: ignore
+                p=norm_type,
+            )
+
+            # the local norm can be accumulated with the remote ones if raised to the right power
+            # n_i = sum_rank(a^p)^(1/p)
+            # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^(1/p) = sum_i(sum_rank(a^p))^(1/p)
+            total_norm = local_norm ** norm_type
+            dist.all_reduce(total_norm, group=self.group)
+            total_norm = total_norm ** (1.0 / norm_type)
+
+        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
+        if clip_coef < 1:
+            for device, device_params in self.per_device_params.items():
+                for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
+                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore
+
+        return total_norm
+
+    # State dict interfaces
     def local_state_dict(self) -> dict:
         """Gets this rank's state_dict.
@@ -345,6 +409,14 @@ class OSS(Optimizer):
             if len(param_groups) == len(self.optim.param_groups) + 1:
                 self.optim.add_param_group(param_groups[-1])

+    @staticmethod
+    def get_global_rank(group: Any, rank: int) -> int:
+        if group is dist.group.WORLD:
+            return rank
+        else:
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+        return global_rank
+
     def _sync_param_groups(self, local_to_global: bool = False) -> None:
         """Sync learning rate and other optimizer attributes (needed to support schedulers).
         If the global param groups have been altered, and we want to make sure that the
@@ -417,14 +489,6 @@ class OSS(Optimizer):
                 for t in p["params"]:
                     t.grad = None

-    @staticmethod
-    def get_global_rank(group: Any, rank: int) -> int:
-        if group is dist.group.WORLD:
-            return rank
-        else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
-        return global_rank
-
     def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
         """Helper function to broadcast all the parameters from a given device"""
         buffer_size = buffers[0].numel()
...
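For reference, the accumulation performed in `clip_grad_norm` above can be written out explicitly (a restatement of the in-code comment, with g the concatenation of all gradients and S_i the parameter shard owned by rank i):

    n_i = \Big( \sum_{j \in S_i} |g_j|^p \Big)^{1/p},
    \qquad
    \|g\|_p = \Big( \sum_i n_i^p \Big)^{1/p} = \Big( \sum_i \sum_{j \in S_i} |g_j|^p \Big)^{1/p},
    \qquad
    \|g\|_\infty = \max_i n_i

Each rank therefore only contributes n_i^p to a summing all-reduce (or n_i to a max all-reduce for the infinity norm), and every gradient is then rescaled by \min\big(1,\ \text{max\_norm} / (\|g\|_p + 10^{-6})\big), matching the behaviour of `torch.nn.utils.clip_grad_norm_`.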
@@ -600,6 +600,7 @@ class Tensor:
     def mode(self, dim: _int=-1, keepdim: _bool=False) -> Tuple[Tensor, Tensor]: ...
     @overload
     def mode(self, dim: Union[str, None], keepdim: _bool=False) -> Tuple[Tensor, Tensor]: ...
+    def mul_(self, value: Union[_float, _int, Tensor]): ...
     def multinomial(self, num_samples: _int, replacement: _bool=False, *, generator: Generator=None) -> Tensor: ...
     def mv(self, vec: Tensor) -> Tensor: ...
     def mvlgamma(self, p: _int) -> Tensor: ...
...
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 from . import Tensor
-from typing import Tuple, List, Union
+from typing import Tuple, List, Union, Optional, Any

 def split(tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int=0) -> Tuple[Tensor,...]: ...
 def einsum(equation: str, *operands: Tensor): ...
+def norm(input: Tensor, p: Union[int, float, Any], dim: Optional[List[int]]=None, keep_dim: Optional[bool]=False, out: Optional[Tensor]=None, dtype: Optional[int]=None) -> Tensor: ...
@@ -8,6 +8,8 @@
 # pylint: disable=missing-function-docstring

+import copy
+from math import inf
 import tempfile
 import unittest
@@ -16,6 +18,7 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP

 import fairscale.optim as optim
@@ -447,3 +450,78 @@ def test_multiple_groups():
     mp.spawn(
         run_test_multiple_groups, args=(world_size, temp_file_name), nprocs=world_size, join=True,
     )
+
+
+def run_gradient_clipping(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name, backend="gloo")
+    device = torch.device(rank)
+    torch.manual_seed(rank)  # make sure that the different ranks get different data
+
+    # Run a dummy step so that the optimizer state dict exists
+    batch, input_width, hidden, target_width = 3, 20, 10, 5
+    target = torch.rand((batch, target_width), device=device)
+    inputs = torch.rand((batch, input_width), device=device)
+
+    NORMS = [1.0, 2.0, 1, 2, inf]
+    CLIP_NORM = 0.3
+
+    def check(norm):
+        model_oss = torch.nn.Sequential(
+            torch.nn.Linear(input_width, hidden),
+            torch.nn.Linear(hidden, hidden),
+            torch.nn.Linear(hidden, target_width),
+        ).to(device)
+        model = copy.deepcopy(model_oss)
+
+        # For this test the gradients are (all) reduced in the same way between the torch reference and fairscale.
+        # Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
+        # gradient norm computation from OSS and adds a dependency.
+        # To keep the comparison apples-to-apples, DDP is used in both cases
+        model_oss = DDP(module=model_oss, device_ids=[rank],)
+        sharded_optimizer = optim.OSS(model_oss.parameters(), lr=0.1, momentum=0.99)
+
+        model = DDP(model, device_ids=[rank],)
+
+        loss_fn = torch.nn.L1Loss()
+        loss_fn.to(device)
+
+        model.zero_grad()
+        model_oss.zero_grad()
+
+        outputs = model(inputs)
+        outputs_oss = model_oss(inputs)
+
+        loss = loss_fn(outputs, target)
+        loss.backward()
+
+        loss_oss = loss_fn(outputs_oss, target)
+        loss_oss.backward()
+
+        # Check the equivalence with the non-sharded optim
+        oss_total_norm = sharded_optimizer.clip_grad_norm(CLIP_NORM, norm_type=norm)
+        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM, norm_type=norm)
+        assert torch.allclose(oss_total_norm, total_norm), "torch and fairscale should return the same grad norm"
+
+        # Check that the params have indeed been clipped
+        for params in sharded_optimizer.per_device_params.values():
+            for param in filter(lambda x: x.grad is not None, params[rank]):
+                assert torch.norm(param.grad, p=norm) < CLIP_NORM, f"param grad norm above clip : {param.grad}"
+
+    for norm in NORMS:
+        print(f"Checking norm {norm}")
+        check(norm)
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+def test_gradient_clipping():
+    world_size = 3
+    temp_file_name = tempfile.mkstemp()[1]
+
+    if torch.cuda.is_available():
+        world_size = min(world_size, torch.cuda.device_count())
+    reference_rank = 0
+
+    mp.spawn(
+        run_gradient_clipping, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+    )