Unverified Commit 52d066a2 authored by Anupam Bhatnagar, committed by GitHub

[feature] [experimental] Layerwise Gradient Scaler (#879)

* [skip ci] first commit

* [skip ci] gradient scaler example

* [skip ci] adding feed forward toy example

* [skip ci] adding types

* [skip ci] adding backward hook

* [skip ci] update

* [skip ci] working feed forward example

* [skip ci] working feed forward example

* [skip ci] use named_modules instead of named_children

* [skip ci] adding new file

* [skip ci] clean up

* [skip ci] implement unscale function

* [skip ci] implement unscale function

* [skip ci] removing old file

* [skip ci] removing some more old files

* [skip ci] making unscale function generic

* [skip ci] adding test for vision model

* [skip ci] adding identity layer

* [skip ci] cleanup files

* [skip ci] refactoring

* [skip ci] more refactoring

* [skip ci] added functionality to update scale

* [skip ci] data loader clean up

* [skip ci] implemented inf checks and update scale functions

* [skip ci] code clean up. added...
parent fb4eca19
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [0.4.5] - TBD
### Added
- Layer-wise Gradient Scaling [new feature][experimental]: layer-wise gradient
scaling helps overcome gradient overflow issues. When used in conjunction with
mixed precision, it enables training larger models and makes the training
process more stable, especially in deep networks [#879]
- FSDP: Added state_dict_on_rank_0_only flag to allow users to choose to return the full
state dict on rank 0 and an empty dict on non-zero ranks to prevent OOM [#844]
- FSDP: Added process_group_reduce_scatter parameter to allow users to pass in the process group that is used for the reduce scatter operation. [#897]
@@ -11,8 +11,8 @@ import warnings
import torch
from torch.cuda import FloatTensor  # type: ignore
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.common import amp_definitely_not_available
from torch.cuda.amp.grad_scaler import GradScaler as TorchGradScaler
import torch.distributed as dist
from torch.optim import Optimizer
from torch.optim.sgd import SGD
import logging
from typing import List, Tuple
import torch
import torch.nn as nn
class LayerInfo:
"""
A class to record the layer attributes.
"""
def __init__(self, name: str, layer: nn.Module, scale: float = 1.0, scale_layer: bool = False) -> None:
"""
layer_name: name of the layer e.g. fc1, conv1, relu1
layer: type of the layer e.g. Linear, Conv2d, ReLU
scaling_factor: user configurable scaling factor for the layer, defaults to 1.0
        found_inf_or_nan: a boolean indicating if any of the layer's parameter gradients contain an inf/nan
        growth_tracker: tracks the number of steps since the last time the scale was increased
scale_layer: a boolean indicating if the layer should be scaled or not
"""
self.layer_name = name
self.layer = layer
self.scaling_factor = scale
self.found_inf_or_nan = False
self.growth_tracker = 0
self.scale_layer = scale_layer
class GradientHelper:
"""
A helper class to create instances of backward hooks. The hooks are registered in the
scale method of LayerwiseGradientScaler.
"""
def __init__(self, name: str, inputs_multiplier: float, outputs_multiplier: float):
self.layer_name = name
self.inputs_multiplier = inputs_multiplier
self.outputs_multiplier = outputs_multiplier
def scale_gradients(self, m: nn.Module, inputs: Tuple, outputs: Tuple) -> Tuple[torch.Tensor]:
"""
Backward hook that is attached to the layers to scale the gradients.
"""
scaled_up_grads = list()
for idx in range(len(inputs)):
if inputs[idx] is not None:
if self.inputs_multiplier != 1.0 or self.outputs_multiplier != 1.0:
logging.debug(
"layer = %s \t scale = %s \t scale_down = %s"
% (self.layer_name, self.inputs_multiplier, self.outputs_multiplier)
)
scaled_up_grads.append(inputs[idx].mul(self.inputs_multiplier * self.outputs_multiplier))
else:
logging.debug("next layer is None")
scaled_up_grads.append(inputs[idx])
return tuple(scaled_up_grads) # type: ignore
class LayerwiseGradientScaler:
"""
LayerwiseGradientScaler enables using distinct scaling factors for each layer
of the network.
Example:
# Create a convolutional network
class ConvNet(nn.Module):
def __init__(self):
...
def forward(self, x):
...
# Create an instance of the model
model = ConvNet()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# specify the layers to scale and their scaling factor
layer_scale_dict = {"conv1": 2**10, "conv2": 2**8, "fc1": 2**10, "fc2": 2**9}
scaler = LayerwiseGradientScaler(model, layer_scale_dict)
        for epoch in range(num_epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()
                # scale the gradients
                scaler.scale()
                # enable mixed precision training
                with autocast():
                    predictions = model(inputs)
                    loss = loss_function(predictions, targets)
                loss.backward()
                # unscale the gradients
                scaler.unscale()
                # a step is taken only if there are no inf/nan in the gradients;
                # the scaling factor of each layer is updated in either case
                scaler.step(optimizer)
Args:
model : instance of a Model class, such as ConvNet above
layer_scale_dict (dict) : dictionary with key = layer_name and value = scaling_factor
        growth_factor (float) : factor by which a layer's scaling factor is multiplied when it is increased
        backoff_factor (float) : factor by which a layer's scaling factor is multiplied when an inf/nan is found
        growth_interval (int) : number of consecutive steps without an inf/nan after which the scaling factor is multiplied by growth_factor
        min_scale (float) : smallest allowed scaling factor
        max_scale (float) : largest allowed scaling factor
"""
def __init__( # type: ignore
self,
model,
layer_scale_dict: dict,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 10000,
min_scale: float = torch.finfo(torch.float32).tiny, # type: ignore
max_scale: float = torch.finfo(torch.float32).max, # type: ignore
) -> None:
self._model = model
self._layer_scale_dict: dict = layer_scale_dict
self._growth_factor: float = growth_factor
self._backoff_factor: float = backoff_factor
self._growth_interval: int = growth_interval
        self._apply_layerwise_scaling: bool = len(layer_scale_dict) > 0
self._min_scale = min_scale
self._max_scale = max_scale
self._handles: List = []
self.layer_info: List = []
if self._apply_layerwise_scaling:
assert self._growth_factor > 1.0, "The growth factor must be > 1.0."
assert self._backoff_factor < 1.0, "The backoff factor must be < 1.0."
self.layer_info = self._build_layer_info()
def _build_layer_info(self) -> List:
"""
Helper function to create a list of LayerInfo instances.
"""
layer_info_list = list()
for name, layer in self._model.named_modules():
if name != "":
if name not in self._layer_scale_dict.keys():
logging.debug("name = %s, layer = %s, scaling_factor = %s" % (name, layer, 1.0))
layer_info_list.append(LayerInfo(name, layer, 1.0))
else:
logging.debug(
"name = %s, layer = %s, scaling_factor = %s" % (name, layer, self._layer_scale_dict[name])
)
layer_info_list.append(LayerInfo(name, layer, self._layer_scale_dict[name], True))
return layer_info_list
def scale(self) -> None:
"""
        For each layer, compute two multipliers: the preceding layer's scaling
        factor (applied to the gradients handed back to that layer) and the
        reciprocal of the current layer's scaling factor. These values are used to
        register a full backward hook on the layer. The handle returned from
        registering the hook is appended to a list of handles. New hooks are
        created and registered at every step, and the handles are removed in the
        unscale function.
"""
if not self._apply_layerwise_scaling:
return
for idx in range(len(self.layer_info)):
elt = self.layer_info[idx]
layer_name, layer = elt.layer_name, elt.layer
inputs_multiplier = 1.0
if idx > 0:
inputs_multiplier = self.layer_info[idx - 1].scaling_factor
outputs_multiplier = 1.0 / elt.scaling_factor
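            # The hook registered below multiplies the gradients this layer hands
            # back to the preceding module by inputs_multiplier * outputs_multiplier,
            # i.e. (preceding module's scale) / (this module's scale). For example,
            # if the preceding module's scale is 2**10 and this module's scale is
            # 2**8, the gradient passed back is multiplied by 2**10 / 2**8 = 4.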
helper = GradientHelper(layer_name, inputs_multiplier, outputs_multiplier)
layer_handle = layer.register_full_backward_hook(helper.scale_gradients)
self._handles.append(layer_handle)
logging.debug("name = %s \t scale = %s" % (layer_name, elt.scaling_factor))
def _get_layers_with_finite_values(self) -> List[LayerInfo]:
layers_with_finite_values: List = []
for item in self.layer_info:
if not item.found_inf_or_nan:
layers_with_finite_values.append(item)
return layers_with_finite_values
def unscale(self) -> None:
"""
        For each layer, check if any of the layer's parameter gradients contain an
        inf/nan. If no inf/nan is found, the gradients of that layer are unscaled
        by multiplying them with the reciprocal of that layer's scaling factor.
        Finally, all handles recorded while registering the hooks are removed.
"""
if not self._apply_layerwise_scaling:
return
layers_with_finite_values = self._get_layers_with_finite_values()
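        # In this scheme, each layer's parameter gradients carry that layer's own
        # scaling factor (applied by the backward hooks), so multiplying by the
        # reciprocal restores the unscaled gradients before the optimizer step.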
for item in layers_with_finite_values:
for param_name, param in item.layer.named_parameters():
if hasattr(param, "grad"):
logging.debug("%s scaling down %s by %s" % (item.layer_name, param_name, 1.0 / item.scaling_factor))
param.grad.mul_(1.0 / item.scaling_factor)
while len(self._handles) > 0:
elt = self._handles.pop()
elt.remove()
def _check_for_inf_or_nan(self) -> None:
"""
        For each layer, check if any of the parameters with a gradient contain an
        inf/nan. If any parameter's gradient contains an inf/nan, that layer's
        found_inf_or_nan attribute is set to True and the remaining parameters of
        that layer are skipped.
"""
for elt in self.layer_info:
elt.found_inf_or_nan = False
for _, param in elt.layer.named_parameters():
if hasattr(param, "grad") and param.grad is not None:
if torch.isinf(param.grad).any().item() or torch.isnan(param.grad).any().item(): # type: ignore
elt.found_inf_or_nan = True
break # skip all remaining named parameters
def step(self, optimizer) -> None: # type: ignore
"""
        If none of the layers' gradients contain an inf/nan, the optimizer takes a
        step; otherwise the step is skipped. The scaling factor of each layer is
        then updated.
"""
# using layerwise gradient scaling
if self._apply_layerwise_scaling:
self._check_for_inf_or_nan()
inf_nan_found = any(elt.found_inf_or_nan for elt in self.layer_info)
if not inf_nan_found:
optimizer.step()
self._update_scale()
# not using layerwise gradient scaling
else:
optimizer.step()
def _update_scale(self) -> None:
"""
For each layer, if an inf/nan is found, then multiply the scaling factor
of that layer by the backoff factor and set the growth tracker of that
layer to 0. Else, increment the growth tracker of the layer. If growth
tracker equals the growth interval, then multiply the scaling factor of
        the layer by the growth factor and reset the layer's growth tracker to 0.
        Finally, clip the scaling factor to the range [self._min_scale,
        self._max_scale]. The min/max scale values are user configurable.
"""
if not self._apply_layerwise_scaling:
return
for layer in self.layer_info:
if layer.found_inf_or_nan:
if layer.scale_layer:
layer.scaling_factor = max(
self._min_scale,
min(self._backoff_factor * layer.scaling_factor, self._max_scale),
)
layer.growth_tracker = 0
else:
layer.growth_tracker += 1
if layer.scale_layer and layer.growth_tracker == self._growth_interval:
layer.scaling_factor = max(
self._min_scale,
min(self._growth_factor * layer.scaling_factor, self._max_scale),
)
layer.growth_tracker = 0
def get_layer_info(self) -> List[LayerInfo]:
"""
Returns a list of LayerInfo instances of the model.
"""
return self.layer_info
def get_backward_hooks(self) -> List:
"""
        Returns a list of tuples. Each tuple contains a layer name and the
        backward hooks attached to that layer.
"""
layer_name_and_hooks = list()
for name, layer in self._model.named_modules():
if name != "":
layer_name_and_hooks.append((name, layer._get_backward_hooks()))
return layer_name_and_hooks
@@ -27,4 +27,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision", "utils"]
known_third_party = ["benchmark_dataset", "datasets", "golden_configs", "models", "numpy", "parameterized", "pytest", "recommonmark", "setuptools", "sklearn", "torch", "torchtext", "torchvision", "utils"]
@@ -28,3 +28,6 @@ pynvml == 8.0.4
# For mypy typing
numpy >= 1.21
# For layerwise gradient scaler
sklearn >= 0.0
@@ -9,3 +9,4 @@ tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py
tests/optim/test_layerwise_gradient_scaler.py
import logging
from typing import Any, List, Tuple, Union
import numpy as np
import pytest
from sklearn.datasets import make_blobs
import torch
from torch.cuda.amp.autocast_mode import autocast
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from fairscale.optim.layerwise_gradient_scaler import LayerwiseGradientScaler
# Test: feed forward network
class FeedForward(torch.nn.Module):
def __init__(self, input_size: int, hidden_size: int):
torch.manual_seed(7)
super(FeedForward, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.fc1 = nn.Linear(self.input_size, self.hidden_size)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
self.relu2 = nn.ReLU()
self.fc3 = nn.Linear(self.hidden_size, 1)
self.sigmoid = nn.Sigmoid()
self.identity = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
out = self.fc1(x)
out = self.relu1(out)
out = self.fc2(out)
out = self.relu2(out)
out = self.fc3(out)
out = self.sigmoid(out)
out = self.identity(out)
return out
# assign labels
def blob_label(y: np.ndarray, label: int, loc: List) -> np.ndarray:
target = np.copy(y) # type: ignore
for l in loc:
target[y == l] = label
return target
def load_data(model_type: str) -> Union[DataLoader, Tuple[Any, Any]]:
data = None
if model_type == "linear_model":
torch.manual_seed(11)
x_train, y_train = make_blobs(n_samples=40, n_features=2, cluster_std=1.5, shuffle=True, random_state=10)
x_train = torch.FloatTensor(x_train)
y_train = torch.FloatTensor(blob_label(y_train, 0, [0]))
y_train = torch.FloatTensor(blob_label(y_train, 1, [1, 2, 3]))
data = (x_train, y_train)
if model_type == "vision_model":
torch.manual_seed(10)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_ds = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_ds_loader = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=2)
image, _ = train_ds[0]
assert image.shape == torch.Size([3, 32, 32])
data = train_ds_loader # type: ignore
return data
def get_params_with_grad(trained_model):
result = []
for module_name, layer in trained_model.named_modules():
if module_name != "":
for param_name, param in layer.named_parameters():
if hasattr(param, "grad"):
logging.debug("testing equality for %s.%s" % (module_name, param_name))
result.append(param.grad)
return result
def train_linear_model(model: FeedForward, per_layer_scaling=False) -> FeedForward:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
x_train, y_train = load_data("linear_model")
num_epochs = 2
model.train()
layers_to_scale = {"fc1": 1024, "fc2": 512, "fc3": 1024} if per_layer_scaling else {}
layer_scaler = LayerwiseGradientScaler(model, layers_to_scale)
for _ in range(num_epochs):
optimizer.zero_grad()
# scale the gradients
layer_scaler.scale()
with autocast():
# forward pass
y_pred = model(x_train)
# compute loss
loss = criterion(y_pred.squeeze(), y_train)
loss.backward()
# unscale the gradients
layer_scaler.unscale()
# update weights and scaling factor
layer_scaler.step(optimizer)
return model
def test_linear_model() -> None:
model1 = FeedForward(2, 10)
model2 = FeedForward(2, 10)
vanilla_model = train_linear_model(model1, False)
scaled_model = train_linear_model(model2, True)
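    # When no inf/nan is produced, scaling followed by unscaling is a no-op, so
    # the gradients from the scaled and unscaled runs should match.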
for elt in zip(get_params_with_grad(vanilla_model), get_params_with_grad(scaled_model)):
assert torch.allclose(elt[0], elt[1])
# Test: convolutional network
class SimpleConvNet(nn.Module):
def __init__(self):
torch.manual_seed(24)
super(SimpleConvNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, 10)
self.identity = nn.Identity()
def forward(self, x):
out = self.conv1(x)
out = self.relu1(out)
out = self.pool1(out)
out = self.conv2(out)
out = self.relu2(out)
out = self.pool2(out)
out = torch.flatten(out, 1) # flatten all dimensions except batch
out = self.fc1(out)
out = self.relu3(out)
out = self.fc2(out)
out = self.relu4(out)
out = self.fc3(out)
out = self.identity(out)
return out
def train_vision_model(model: SimpleConvNet, per_layer_scaling=False):
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
if torch.cuda.is_available():
model.cuda()
train_ds_loader = load_data("vision_model")
model.train()
layer_scale_dict = {"conv1": 128, "conv2": 256, "fc1": 512, "fc2": 1024, "fc3": 8192} if per_layer_scaling else {}
layer_scaler = LayerwiseGradientScaler(model, layer_scale_dict)
for _ in range(2):
for img, lbl in train_ds_loader:
if torch.cuda.is_available():
img = img.cuda()
lbl = lbl.cuda()
optimizer.zero_grad()
layer_scaler.scale()
predict = model(img)
loss = loss_fn(predict, lbl)
loss.backward()
layer_scaler.unscale()
layer_scaler.step(optimizer)
return model
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_vision_model() -> None:
# Remove randomness from various sources while testing.
torch.use_deterministic_algorithms(True) # type: ignore
    # CUBLAS_WORKSPACE_CONFIG=:4096:8 must be set in the environment (e.g. in CircleCI) for this test to pass
m1 = SimpleConvNet()
m2 = SimpleConvNet()
vision_model = train_vision_model(m1, False)
scaled_vision_model = train_vision_model(m2, True)
for elt in zip(get_params_with_grad(vision_model), get_params_with_grad(scaled_vision_model)):
assert torch.allclose(elt[0], elt[1])