"...text-generation-inference.git" did not exist on "d14eaacacab9ca3056a9d001d0ca2dc0a36edfde"
Commit d85acf72, authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedOptim: Distributed Grad Scaler (for torch AMP) (#182)

* adding a shard-aware GradScaler wrap, credits to Sean Naren for the idea
* adding stubs & explanations in the documentation
parent 587b707d
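In short, `ShardedGradScaler` is meant as a drop-in replacement for `torch.cuda.amp.GradScaler` whenever the optimizer state is sharded with OSS. A minimal usage sketch, mirroring the documentation added below; `MyModel`, `my_dataloader` and the loss function are placeholders, not part of this commit:

.. code-block:: python

    import torch
    from torch.cuda.amp import autocast

    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

    # Placeholders: any model, dataloader and loss function can be used here
    model = MyModel().cuda()
    dataloader = my_dataloader()
    loss_fn = torch.nn.CrossEntropyLoss()

    # Shard the optimizer state across ranks, then use the shard-aware scaler
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4)
    scaler = ShardedGradScaler()

    for inputs, target in dataloader:
        optimizer.zero_grad()
        with autocast():
            loss = loss_fn(model(inputs), target)
        scaler.scale(loss).backward()  # accumulate scaled gradients
        scaler.step(optimizer)         # syncs the inf/nan status across ranks, then steps
        scaler.update()                # adjusts the scale for the next iteration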
@@ -14,6 +14,7 @@ from typing import Any, List, Optional, cast
 import numpy as np
 import torch
 import torch.autograd.profiler as profiler
+from torch.cuda.amp import GradScaler as TorchGradScaler
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -25,6 +26,7 @@ from torchvision.transforms import ToTensor
 from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
 from fairscale.optim import OSS
+from fairscale.optim.grad_scaler import ShardedGradScaler

 OPTIM = torch.optim.RMSprop
 TEMPDIR = tempfile.gettempdir()
@@ -92,6 +94,7 @@ def train(
     # Shard the optimizer
     optimizer: Optional[torch.optim.Optimizer] = None
     model = cast(nn.Module, model)
+    scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None

     if optim_type == OptimType.oss_sharded_ddp:
         model = ShardedDDP(
@@ -102,7 +105,6 @@ def train(
             broadcast_buffers=True,
         )
         optimizer = model.sharded_optimizer
-
     else:
         if args.cpu:
             device_ids = None
@@ -136,7 +138,7 @@ def train(
         for batch in dataloader:
             batch__start = time.monotonic()

-            def closure(data=batch):
+            def closure(data=batch, grad_scaler=None):
                 model.zero_grad()
                 if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                     logging.debug(
@@ -144,16 +146,18 @@ def train(
                             next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                         )
                     )

-                if not args.cpu and args.amp:
+                if grad_scaler is not None:
                     # Automatically computes the FW pass in half precision
                     with torch.cuda.amp.autocast():
                         outputs = model(data["inputs"])
                         loss = loss_fn(outputs, data["label"])
+                    # Accumulates scaled gradients.
+                    grad_scaler.scale(loss).backward()
                 else:
                     outputs = model(data["inputs"])
                     loss = loss_fn(outputs, data["label"])
+                    loss.backward()

-                loss.backward()

                 if optim_type == OptimType.oss_sharded_ddp:
                     model.reduce()
@@ -170,16 +174,24 @@ def train(
                 logging.info("Profiling the run")
                 with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                     with profiler.record_function("batch"):
-                        final_loss = optimizer.step(closure)
-                        logging.info("profiling done")
+                        if scaler is not None:
+                            final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
+                            scaler.step(optimizer)
+                            scaler.update()
+                        else:
+                            final_loss = optimizer.step(closure)

-                if rank == 0:
-                    prof.export_chrome_trace(f"{optim_type}_trace.json")
+                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                 need_profiling = False  # only profile once
             else:
-                final_loss = optimizer.step(closure)
+                if scaler is not None:
+                    final_loss = closure(grad_scaler=scaler)  # AMP scaler.step does not support closures
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    final_loss = optimizer.step(closure)

             if args.debug and rank == 0:
                 logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
...
@@ -6,4 +6,5 @@ API Reference
     optim/adascale
     optim/oss
+    optim/grad_scaler
     nn/pipe
Sharded Grad Scaler
========================

Enabling PyTorch's automatic mixed precision usually means using a `GradScaler` to scale the loss and guard against gradient underflow.
The stock grad scaler is not aware of the optimizer state sharding performed by Fairscale's OSS, which will lead to deadlocks.
Make sure that you use `ShardedGradScaler` in that case; it is a shard-aware wrapper around PyTorch's implementation.
.. code-block:: python

    import torch
    from fairscale.optim.oss import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # ** NEW ** Use a shard-aware grad scaler in place of torch.cuda.amp.GradScaler
        scaler = ShardedGradScaler()

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                model.zero_grad()

                # Automatically computes the FW pass in half precision
                with torch.cuda.amp.autocast():
                    outputs = model(data)
                    loss = loss_fn(outputs, target)

                # Automatically handle scaled gradients: accumulate them, then unscale
                # and step only if no inf/NaN was found (same decision on every rank)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
@@ -29,6 +29,7 @@ Components
   * `tensor parallelism <../../build/html/api/nn/model_parallel.html>`_
 * Optimization:
   * `optimizer state sharding <../../build/html/api/optim/oss.html>`_
+  * `sharded grad scaler - AMP <../../build/html/api/optim/grad_scaler.html>`_
   * `AdaScale SGD <../../build/html/api/optim/adascale.html>`_
...
@@ -86,7 +86,6 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
                 model.zero_grad()
                 outputs = model(data)
                 loss = loss_fn(outputs, target)
-                loss /= world_size
                 loss.backward()
                 optimizer.step()
@@ -104,5 +103,45 @@ The above `train` function will then need to be run via a `multiprocessing.spawn`
     )

 to see it in action, you can test it with the following script `here <../../../examples/tutorial_oss.py>`_.
Using PyTorch Automatic Mixed Precision is possible, but it requires a shard-aware grad scaler, which is available in
`fairscale.optim.grad_scaler`. Autocast can be used as is, and the loss will be scaled and handled in the same way.
See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_
for more information.
.. code-block:: python

    import torch.optim as optim
    from torch.cuda.amp import autocast
    from fairscale.optim.grad_scaler import ShardedGradScaler

    # Creates model and optimizer in default precision
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    # Creates a ShardedGradScaler once at the beginning of training.
    scaler = ShardedGradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass with autocasting.
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            # Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended:
            # backward ops run in the same dtype autocast chose for the corresponding forward ops.
            scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for the next iteration.
            scaler.update()
@@ -57,7 +57,8 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     training_start = time.monotonic()

     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()
-    for e in range(epochs):
+
+    for _ in range(epochs):
         for (data, target) in dataloader:
             data, target = data.to(rank), target.to(rank)
...
@@ -3,15 +3,44 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict
+from typing import Any, Dict, Optional

-from torch import Tensor, device
+import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
+import torch.distributed as dist
 from torch.optim import Optimizer

+from .oss import OSS
+

 class GradScaler(TorchGradScaler):
     def _unscale_grads_(
-        self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool
-    ) -> Dict[device, Tensor]:
+        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+    ) -> Dict[torch.device, torch.Tensor]:
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, True)
+
+
+class ShardedGradScaler(TorchGradScaler):
+    """
+    A shard-aware :class:`GradScaler<torch.cuda.amp.GradScaler>`, to be used in conjunction with
+    :class:`OSS` and :class:`ShardedOptimizer`.
+
+    Interface and use cases are not changed; more explanations can be found in the corresponding PyTorch
+    documentation: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> Optional[float]:
+        assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"
+
+        # Re-use the GradScaler machinery, but make sure that the inf/nan status is synced between the ranks
+        optimizer_state = self._per_optimizer_states[id(optimizer)]
+        handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]
+
+        # Make sure that the async reductions are done before moving on
+        _ = list(map(lambda x: x.wait(), handles))
+
+        # Call Torch's GradScaler in turn, now that the states have been synchronized across the ranks
+        return super().step(optimizer, *args, **kwargs)
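The all_reduce above is the crux of the change: with OSS, each rank only checks the gradients of its own optimizer shard for inf/nan, so without synchronization one rank could skip `optimizer.step()` while another performs it, and the collective operations inside OSS would then deadlock. A rough sketch of that synchronization in isolation; the helper name is hypothetical and `found_inf_per_device` mirrors the per-optimizer state kept by `torch.cuda.amp.GradScaler`:

.. code-block:: python

    from typing import Dict

    import torch
    import torch.distributed as dist

    def sync_found_inf(found_inf_per_device: Dict[torch.device, torch.Tensor]) -> None:
        # Hypothetical helper mirroring ShardedGradScaler.step above: sum the per-rank
        # inf/nan flags so that every rank takes the same step-or-skip decision.
        handles = [dist.all_reduce(flag, async_op=True) for flag in found_inf_per_device.values()]
        for handle in handles:
            handle.wait()
        # After the all_reduce, a flag raised on any rank is visible to all ranks.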
@@ -2,7 +2,13 @@
 from ...optim import Optimizer
 from ... import device, Tensor

-from typing import Dict
+from typing import Dict, Any, Optional

 class GradScaler(object):
+    _scale: Optional[Tensor]
+    _grows_tracker: Optional[Tensor]
+    _per_optimizer_states: Dict[int, Dict[str, Any]]
     def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]: ...
+    def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ...
+    def update(self, new_scale: Optional[float] = None): ...