Unverified commit dcb02fcf authored by Tim Moon, committed by GitHub

Gradient clipping with fused kernels (#1405)

* Gradient clipping routine with fused kernels

Identical API to PyTorch. Falls back to the PyTorch implementation when not computing the L2 norm.

* Add unit test for gradient clipping

* Add fp16 case to gradient clipping unit test

* Tweaks to grad clipping unit test

Review suggestions from @crcrpar

* Debug gradient clipping tests

When checking that incorrect results produce assertion errors, make sure to generate a discrepancy outside the range of numerical error.
parent 1403c21a
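The new routine is a drop-in replacement for torch.nn.utils.clip_grad_norm_. A minimal usage sketch, assuming Apex is installed with its fused CUDA extensions (amp_C), and using a placeholder model and optimizer:

import torch
from apex.contrib.clip_grad import clip_grad_norm_

model = torch.nn.Linear(16, 16).cuda()              # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder optimizer

loss = model(torch.randn(8, 16, device='cuda')).sum()
loss.backward()

# Same signature as torch.nn.utils.clip_grad_norm_; since all gradients here
# are fp32 CUDA tensors, the fused L2-norm and scaling kernels are used.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()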
from .clip_grad import clip_grad_norm_
import torch
from torch._six import inf
from typing import Union, Iterable

_kernel_import_succeeded = False
try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier
    _kernel_import_succeeded = True
except ImportError:
    _kernel_import_succeeded = False

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False) -> torch.Tensor:
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    This is identical to torch.nn.utils.clip_grad_norm_, except it
    uses a fused CUDA kernel when computing the 2-norm of GPU tensors
    in float32 and float16.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
            for infinity norm.
        error_if_nonfinite (bool): if True, an error is thrown if the total
            norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the
            future)

    Returns:
        Total norm of the parameters (viewed as a single vector).

    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Trivial case
    if len(parameters) == 0:
        return torch.tensor(0.)

    # Fallback implementation
    if not (_kernel_import_succeeded
            and norm_type == 2.0
            and any(p.is_cuda for p in parameters)):
        return torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm,
            norm_type=norm_type,
            error_if_nonfinite=error_if_nonfinite,
        )

    # Find fp32 and fp16 gradients on GPU
    device = next(p.device for p in parameters if p.is_cuda)
    grads_fp32, grads_fp16, grads_misc = [], [], []
    for p in parameters:
        grad = p.grad.detach()
        if p.dtype == torch.float32 and p.device == device:
            grads_fp32.append(grad)
        elif p.dtype == torch.float16 and p.device == device:
            grads_fp16.append(grad)
        else:
            grads_misc.append(grad)

    # Compute gradient L2 norms
    norms = []
    dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device)
    if grads_fp32:
        norms.append(
            multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_fp32],
                False,
            )[0]
        )
    if grads_fp16:
        norms.append(
            multi_tensor_applier(
                amp_C.multi_tensor_l2norm,
                dummy_overflow_buf,
                [grads_fp16],
                False,
            )[0]
        )
    for g in grads_misc:
        norms.append(torch.linalg.norm(g).unsqueeze(0).to(device))
    total_norm = torch.linalg.norm(torch.cat(norms))

    # Check for non-finite values
    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f'The total norm of order {norm_type} for gradients from '
            '`parameters` is non-finite, so it cannot be clipped. To disable '
            'this error and scale the gradients by the non-finite norm anyway, '
            'set `error_if_nonfinite=False`')

    # Scale gradients
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    if grads_fp32:
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            dummy_overflow_buf,
            [grads_fp32, grads_fp32],
            clip_coef_clamped,
        )
    if grads_fp16:
        multi_tensor_applier(
            amp_C.multi_tensor_scale,
            dummy_overflow_buf,
            [grads_fp16, grads_fp16],
            clip_coef_clamped,
        )
    for g in grads_misc:
        g.mul_(clip_coef_clamped.to(g.device))

    return total_norm
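The fused path above computes separate L2 norms for the fp32, fp16, and miscellaneous gradient buckets and then takes the norm of the concatenated partial norms. This equals the norm over all gradients viewed as a single vector, because squared L2 norms add across disjoint chunks. A small pure-PyTorch check of that identity, for illustration only:

import torch

grads = [torch.randn(5), torch.randn(3, 4), torch.randn(7)]

# Norm over all gradients concatenated into a single vector
direct = torch.linalg.norm(torch.cat([g.reshape(-1) for g in grads]))

# Norm of the per-chunk norms, as in the bucketed computation above
chunked = torch.linalg.norm(torch.stack([torch.linalg.norm(g) for g in grads]))

torch.testing.assert_close(direct, chunked)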
import random
import unittest

import torch

from apex.contrib.clip_grad import clip_grad_norm_


def make_params(
        num_params,
        sizes=[1, 2, 3, 4, 5],
        num_dims=[1, 2, 3],
        dtypes=[torch.float32],
        devices=['cuda'],
        make_copy=False,
):
    """Construct parameters with random configurations"""

    # Construct parameters
    params = []
    for _ in range(num_params):
        dims = [random.choice(sizes) for _ in range(random.choice(num_dims))]
        dtype = random.choice(dtypes)
        device = random.choice(devices)
        p = torch.nn.Parameter(torch.randn(dims, dtype=dtype, device=device))
        p.grad = torch.randn_like(p)
        params.append(p)

    # Copy parameters if needed
    if make_copy:
        params_copy = []
        for p in params:
            p_copy = p.clone().detach()
            p_copy.grad = p.grad.clone().detach()
            params_copy.append(p_copy)
        return params, params_copy
    else:
        return params
class ClipGradNormTest(unittest.TestCase):

    def setUp(self, seed=1234):
        random.seed(seed)
        torch.manual_seed(seed)

    def test_matches_pytorch(
            self,
            num_params=41,
            dtypes=[torch.float32, torch.float16, torch.float64],
            devices=['cuda', 'cpu'],
            max_norm=0.54321,
            norm_type=2.0,
            rtol=1e-3,
            atol=1e-20,
    ):
        """Make sure PyTorch and Apex gradient clipping produce same results"""

        # Construct identical sets of parameters
        torch_params, apex_params = make_params(
            num_params,
            dtypes=dtypes,
            devices=devices,
            make_copy=True,
        )

        # Apply gradient clipping
        torch_norm = torch.nn.utils.clip_grad_norm_(
            torch_params,
            max_norm,
            norm_type=norm_type,
        )
        apex_norm = clip_grad_norm_(
            apex_params,
            max_norm,
            norm_type=norm_type,
        )

        # Make sure PyTorch and Apex get same results
        torch.testing.assert_close(
            apex_norm, torch_norm,
            rtol=rtol,
            atol=atol,
            check_dtype=False,
        )
        for torch_p, apex_p in zip(torch_params, apex_params):
            torch.testing.assert_close(
                apex_p, torch_p,
                rtol=0,
                atol=0,
            )  # Params should be unaffected
            torch.testing.assert_close(
                apex_p.grad, torch_p.grad,
                rtol=rtol,
                atol=atol,
            )

    def test_matches_pytorch_fp16(self):
        self.test_matches_pytorch(num_params=11, dtypes=[torch.float16])

    def test_matches_pytorch_fp32(self):
        self.test_matches_pytorch(dtypes=[torch.float32], rtol=1e-6)

    def test_matches_pytorch_fp64(self):
        self.test_matches_pytorch(dtypes=[torch.float64], rtol=1e-15)

    def test_matches_pytorch_cpu(self):
        self.test_matches_pytorch(devices=['cpu'])

    def test_matches_pytorch_infnorm(self):
        self.test_matches_pytorch(norm_type=float('inf'))

    def test_matches_pytorch_1norm(self):
        self.test_matches_pytorch(norm_type=1.0)
    def test_raises_on_mismatch(self):

        # Construct different sets of parameters
        torch_params, apex_params = make_params(7, make_copy=True)
        with torch.no_grad():
            torch_params[0].grad.view(-1)[0] = 1.23
            apex_params[0].grad.view(-1)[0] = 3.21

        # Apply gradient clipping
        torch_norm = torch.nn.utils.clip_grad_norm_(
            torch_params,
            0.54321,
        )
        apex_norm = clip_grad_norm_(
            apex_params,
            0.54321,
        )

        # Make sure PyTorch and Apex get different results
        self.assertRaises(
            AssertionError,
            torch.testing.assert_close,
            apex_norm, torch_norm,
            rtol=1e-3,
            atol=1e-20,
            check_dtype=False,
        )
        for torch_p, apex_p in zip(torch_params, apex_params):
            self.assertRaises(
                AssertionError,
                torch.testing.assert_close,
                apex_p.grad, torch_p.grad,
                rtol=1e-3,
                atol=1e-20,
            )

    def test_raises_on_nan(self):
        params = make_params(5, num_dims=[1])
        params[2].grad[-1] = float('NaN')
        self.assertRaises(
            RuntimeError, clip_grad_norm_, params, 1.0, error_if_nonfinite=True)

    def test_raises_on_inf(self):
        params = make_params(5, num_dims=[1])
        params[2].grad[-1] = float('inf')
        self.assertRaises(
            RuntimeError, clip_grad_norm_, params, 1.0, error_if_nonfinite=True)


if __name__ == "__main__":
    unittest.main()