# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Gradient clipping."""

import torch
from torch._six import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def clip_grad_norm_fp32(parameters, grads_for_norm,
                        max_norm, norm_type=2,
                        model_parallel_group=None):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    The norm is computed from ``grads_for_norm`` and reduced over
    ``model_parallel_group``. If the resulting clip coefficient is below 1,
    the ``.grad`` of every entry in ``parameters`` is scaled in place.
    A rank that holds no gradients still takes part in the collective
    reduction, contributing zero.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized. Every non-None
            ``.grad`` must be a ``torch.cuda.FloatTensor``.
        grads_for_norm (Iterable[Tensor] or Tensor): an iterable of Tensors
            or a single Tensor that will be used for calculating the grad
            norm. May be empty.
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
            for infinity norm.
        model_parallel_group (group): process group over which the norm is
            reduced; given the nature of the distributed optimizer, this is
            passed as an argument.

    Returns:
        Total norm of the gradients (viewed as a single vector), as a
        Python float.
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    if isinstance(grads_for_norm, torch.Tensor):
        grads_for_norm = [grads_for_norm]
    # Materialize so the emptiness checks below also work for generators.
    grads_for_norm = list(grads_for_norm)

    # Grads that will be scaled in place.
    grads = []
    for param in parameters:
        if param.grad is not None:
            assert param.grad.type() == 'torch.cuda.FloatTensor'
            grads.append(param.grad.detach())

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Calculate norm.
    if norm_type == inf:
        # Local max-abs entry. Use 0.0 when this rank has no grads so it
        # can still join the all-reduce below instead of raising.
        local_max = max((grad.abs().max() for grad in grads_for_norm),
                        default=0.0)
        total_norm_cuda = torch.cuda.FloatTensor([float(local_max)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=model_parallel_group)
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            if grads_for_norm:
                # Use apex's multi-tensor applier for efficiency reasons.
                # Multi-tensor applier takes a function and a list of list
                # and performs the operation on that list all in one kernel.
                dummy_overflow_buf = torch.cuda.IntTensor([0])
                grad_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    dummy_overflow_buf,
                    [grads_for_norm],
                    False  # no per-parameter norm
                )
            else:
                grad_norm = torch.cuda.FloatTensor([0])
            # Since we will be summing across the model-parallel group,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type

        else:
            # Start from a CUDA tensor rather than a Python float, so the
            # all-reduce below always receives a tensor, even with no grads.
            total_norm = torch.cuda.FloatTensor([0.0])
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    # When this rank holds no grads there is nothing to scale.
    if clip_coeff < 1.0 and grads:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)

    return total_norm


def count_zeros_fp32(parameters, model_parallel_group):
    """Count zero-valued gradient entries, summed over a process group.

    Only parameters that pass all three filters are counted:
      - the parameter has a ``.grad``;
      - ``param_is_not_shared(param)`` is true;
      - ``param_is_not_tensor_parallel_duplicate(param)`` is true.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): parameters whose ``.grad``
            is inspected. A single Tensor is treated as a one-element list.
        model_parallel_group (group): process group over which the per-rank
            counts are summed with ``all_reduce``.

    Returns:
        The summed zero count, as the Python number produced by ``.item()``
        on a float32 CUDA tensor.
    """
    if isinstance(parameters, torch.Tensor):
        param_list = [parameters]
    else:
        param_list = parameters

    # Running count lives in a float32 CUDA tensor so it can be all-reduced
    # directly.
    zero_count = torch.cuda.FloatTensor([0.0])

    for p in param_list:
        # All three predicates are evaluated (in this order) for every
        # parameter, before deciding whether to count it.
        counted = all((
            p.grad is not None,
            param_is_not_shared(p),
            param_is_not_tensor_parallel_duplicate(p),
        ))
        if not counted:
            continue
        g = p.grad.detach()
        zero_count = (g.numel() - torch.count_nonzero(g)) + zero_count

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(zero_count,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=model_parallel_group)

    return zero_count.item()