Unverified Commit d85acf72 authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedOptim: Distributed Grad Scaler (for torch AMP) (#182)

* adding a shard-aware GradScaler wrap, credits to Sean Naren for the idea
* adding stubs & explanations in the documentation
parent 587b707d
......@@ -14,6 +14,7 @@ from typing import Any, List, Optional, cast
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
......@@ -25,6 +26,7 @@ from torchvision.transforms import ToTensor
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
OPTIM = torch.optim.RMSprop
TEMPDIR = tempfile.gettempdir()
......@@ -92,6 +94,7 @@ def train(
# Shard the optimizer
optimizer: Optional[torch.optim.Optimizer] = None
model = cast(nn.Module, model)
scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None
if optim_type == OptimType.oss_sharded_ddp:
model = ShardedDDP(
......@@ -102,7 +105,6 @@ def train(
broadcast_buffers=True,
)
optimizer = model.sharded_optimizer
else:
if args.cpu:
device_ids = None
......@@ -136,7 +138,7 @@ def train(
for batch in dataloader:
batch__start = time.monotonic()
def closure(data=batch):
def closure(data=batch, grad_scaler=None):
model.zero_grad()
if args.debug and rank == 0 and next(model.parameters()).grad is not None:
logging.debug(
......@@ -144,16 +146,18 @@ def train(
next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
)
)
if not args.cpu and args.amp:
if grad_scaler is not None:
# Automatically computes the FW pass in half precision
with torch.cuda.amp.autocast():
outputs = model(data["inputs"])
loss = loss_fn(outputs, data["label"])
# Accumulates scaled gradients.
grad_scaler.scale(loss).backward()
else:
outputs = model(data["inputs"])
loss = loss_fn(outputs, data["label"])
loss.backward()
loss.backward()
if optim_type == OptimType.oss_sharded_ddp:
model.reduce()
......@@ -170,16 +174,24 @@ def train(
logging.info("Profiling the run")
with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof: # type: ignore
with profiler.record_function("batch"):
final_loss = optimizer.step(closure)
logging.info("profiling done")
if scaler is not None:
final_loss = closure(grad_scaler=scaler) # AMP scaler.step does not support closures
scaler.step(optimizer)
scaler.update()
else:
final_loss = optimizer.step(closure)
if rank == 0:
prof.export_chrome_trace(f"{optim_type}_trace.json")
prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
need_profiling = False # only profile once
else:
final_loss = optimizer.step(closure)
if scaler is not None:
final_loss = closure(grad_scaler=scaler) # AMP scaler.step does not support closures
scaler.step(optimizer)
scaler.update()
else:
final_loss = optimizer.step(closure)
if args.debug and rank == 0:
logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
......
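For readers skimming the diff above: since ``GradScaler.step`` does not accept a closure, the benchmark evaluates the closure by hand before calling ``scaler.step(optimizer)``. The following is a minimal, hedged sketch of that pattern only; ``model``, ``optimizer``, ``loss_fn`` and ``dataloader`` are hypothetical placeholders, not the benchmark's actual objects.

.. code-block:: python

    # Sketch of the closure + grad scaler interplay used in the benchmark above.
    # `model`, `optimizer`, `loss_fn` and `dataloader` are placeholders.
    import torch

    scaler = torch.cuda.amp.GradScaler()  # or ShardedGradScaler() when the optimizer is OSS

    for batch in dataloader:

        def closure():
            model.zero_grad()
            # Forward pass in mixed precision
            with torch.cuda.amp.autocast():
                loss = loss_fn(model(batch["inputs"]), batch["label"])
            # Accumulate scaled gradients
            scaler.scale(loss).backward()
            return loss

        # GradScaler.step() does not support closures, so run the closure explicitly,
        # then let the scaler unscale, (possibly) step, and update its scale factor.
        final_loss = closure()
        scaler.step(optimizer)
        scaler.update()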
......@@ -6,4 +6,5 @@ API Reference
optim/adascale
optim/oss
optim/grad_scaler
nn/pipe
Sharded Grad Scaler
========================
Enabling PyTorch's automatic mixed precision usually means using a ``GradScaler`` to guard against gradient underflow.
This grad scaler is not aware of the optimizer state sharding performed by Fairscale OSS, and using it will lead to deadlocks.
Make sure that you use ``ShardedGradScaler`` in that case; it is a shard-aware wrapper around PyTorch's implementation.
.. code-block:: python
    import torch
    from fairscale.optim.oss import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        scaler = ShardedGradScaler()

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Automatically computes the FW pass in half precision
                with torch.cuda.amp.autocast():
                    model.zero_grad()
                    outputs = model(data)
                    loss = loss_fn(outputs, target)

                # Automatically handle scaled gradients
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
......@@ -29,6 +29,7 @@ Components
* `tensor parallelism <../../build/html/api/nn/model_parallel.html>`_
* Optimization:
* `optimizer state sharding <../../build/html/api/optim/oss.html>`_
* `sharded grad scaler - AMP <../../build/html/api/optim/grad_scaler.html>`_
* `AdaScale SGD <../../build/html/api/optim/adascale.html>`_
......
......@@ -86,7 +86,6 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
model.zero_grad()
outputs = model(data)
loss = loss_fn(outputs, target)
loss /= world_size
loss.backward()
optimizer.step()
......@@ -104,5 +103,45 @@ The above `train` function will then need to be run via a `multiprocessing.spawn
)
To see it in action, you can test it with the following script `here <../../../examples/tutorial_oss.py>`_.
Using PyTorch Automatic Mixed Precision is possible, but it requires a shard-aware grad scaler, which is available in
``fairscale.optim.grad_scaler``. Autocast can be used as is, and the loss will be scaled and handled in the same way.
See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_
for more information.
.. code-block:: python
    import torch
    from torch.cuda.amp import autocast
    from fairscale.optim.oss import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

    # Creates the model and an OSS-wrapped optimizer in default precision
    model = Net().cuda()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, ...)

    # Creates a ShardedGradScaler once at the beginning of training.
    scaler = ShardedGradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass with autocasting.
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()
......@@ -57,7 +57,8 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
training_start = time.monotonic()
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
for _ in range(epochs):
for (data, target) in dataloader:
data, target = data.to(rank), target.to(rank)
......
......@@ -3,15 +3,44 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
from typing import Any, Dict, Optional
from torch import Tensor, device
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
from torch.optim import Optimizer
from .oss import OSS
class GradScaler(TorchGradScaler):
def _unscale_grads_(
self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool
) -> Dict[device, Tensor]:
self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
) -> Dict[torch.device, torch.Tensor]:
return super()._unscale_grads_(optimizer, inv_scale, found_inf, True)
class ShardedGradScaler(TorchGradScaler):
    """
    A shard-aware :class:`GradScaler<torch.cuda.amp.GradScaler>`, to be used in conjunction with
    :class:`OSS` and :class:`ShardedOptimizer`.

    The interface and use cases are unchanged; more explanations can be found in the corresponding PyTorch
    documentation: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
    """

    def __init__(self) -> None:
        super().__init__()

    def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> Optional[float]:
        assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"

        # Re-use the GradScaler machinery, but make sure that the found_inf status is synced across ranks
        optimizer_state = self._per_optimizer_states[id(optimizer)]
        handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]

        # Make sure that all the async all_reduce calls have completed before stepping
        _ = list(map(lambda x: x.wait(), handles))

        # Call Torch's GradScaler step; the inf/NaN states are now consistent across ranks
        return super().step(optimizer, *args, **kwargs)
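For illustration only, here is a minimal sketch of how ``ShardedGradScaler`` slots into an OSS training step; ``MyModel`` and the ``dataloader`` loop are hypothetical stand-ins and not part of this commit.

.. code-block:: python

    # Hypothetical usage sketch: ShardedGradScaler is a drop-in replacement for
    # torch.cuda.amp.GradScaler once the optimizer is wrapped in OSS.
    import torch
    from fairscale.optim.oss import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

    model = MyModel().cuda()  # placeholder model
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
    scaler = ShardedGradScaler()  # step() asserts that the optimizer is an OSS instance

    for inputs, target in dataloader:  # placeholder data loop
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(inputs), target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # found_inf is all-reduced across ranks before stepping
        scaler.update()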
......@@ -2,7 +2,13 @@
from ...optim import Optimizer
from ... import device, Tensor
from typing import Dict
from typing import Dict, Any, Optional
class GradScaler(object):
    _scale: Optional[Tensor]
    _growth_tracker: Optional[Tensor]
    _per_optimizer_states: Dict[int, Dict[str, Any]]
    def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]: ...
    def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ...
    def update(self, new_scale: Optional[float] = None): ...