clip_grads.py 8.12 KB
Newer Older
mohammad's avatar
mohammad committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gradient clipping."""

import torch
from torch._six import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import mpu
mohammad's avatar
mohammad committed
25
26
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
mohammad's avatar
mohammad committed
27
28


Lawrence McAfee's avatar
Lawrence McAfee committed
29
30
# >>>
from lutil import pax, tp
Lawrence McAfee's avatar
Lawrence McAfee committed
31
# Debug aid: iteration index intended to be compared against the `ITERATION`
# argument of clip_grad_norm_fp32 by temporary debug instrumentation. No live
# (non-commented) code visible in this file reads it.
DEBUG_ITERATION = 1
Lawrence McAfee's avatar
Lawrence McAfee committed
32
33
# <<<

Lawrence McAfee's avatar
Lawrence McAfee committed
34
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
    """Clips gradient norm of an iterable of parameters whose gradients
       are in fp32.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Every parameter with a gradient is scaled, but only gradients of
    parameters that are neither shared nor tensor-parallel duplicates
    contribute to the norm, so that they are not counted more than once
    when the per-rank results are reduced.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        ITERATION (int, optional): unused. Kept only so that existing callers
            passing it keep working; it was consulted solely by debug
            instrumentation.

    Returns:
        Total norm of the parameters (viewed as a single vector), as a
        Python float.
    """
    # Function-scope import keeps this block self-contained.
    from megatron import get_args

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads = []
    grads_for_norm = []
    for param in parameters:
        if param.grad is None:
            continue
        # Make sure the grads are in fp32
        assert param.grad.type() == 'torch.cuda.FloatTensor'
        grad = param.grad.detach()
        grads.append(grad)
        if param_is_not_shared(param) and \
           param_is_not_tensor_parallel_duplicate(param):
            grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Group over which the per-rank partial norms are combined. With the
    # distributed optimizer the reduction spans the default (world) group
    # (group=None); otherwise only the model-parallel group.
    # NOTE(review): presumably the distributed optimizer shards gradients
    # across data-parallel ranks -- confirm against the optimizer code.
    if get_args().use_distributed_optimizer:
        reduce_group = None
    else:
        reduce_group = mpu.get_model_parallel_group()

    # Calculate norm.
    if norm_type == inf:
        # default=0.0 lets a rank with no eligible gradients still take part
        # in the collective instead of raising on an empty sequence.
        total_norm = max((grad.abs().max() for grad in grads_for_norm),
                         default=0.0)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across the reduction group. MAX is unaffected by
        # replicated values, so the wider group is also safe here.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=reduce_group)
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            if grads_for_norm:
                dummy_overflow_buf = torch.cuda.IntTensor([0])
                # Use apex's multi-tensor applier for efficiency reasons.
                # Multi-tensor applier takes a function and a list of list
                # and performs the operation on that list all in one kernel.
                grad_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    dummy_overflow_buf,
                    [grads_for_norm],
                    False  # no per-parameter norm
                )
            else:
                # The multi-tensor applier rejects empty tensor lists.
                grad_norm = torch.cuda.FloatTensor([0])
            # Since we will be summing across ranks,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type

        else:
            # Start from a tensor (not a Python float) so that the
            # all-reduce below also works when no gradient qualifies.
            total_norm = torch.cuda.FloatTensor([0])
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type

        # Sum across the reduction group.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=reduce_group)
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale. The small epsilon guards against division by zero.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    # Skip the kernel launch when there is nothing to scale; the
    # multi-tensor applier rejects empty tensor lists.
    if clip_coeff < 1.0 and grads:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)

    return total_norm
Rewon Child's avatar
Rewon Child committed
203
204
205
206
207
208
209
210
211
212
213


def count_zeros_fp32(parameters):
    """Counts the zero-valued gradient entries of the given parameters.

    Only gradients of parameters that are neither shared nor
    tensor-parallel duplicates are counted. The per-rank count is then
    summed across ranks.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose gradients are inspected

    Returns:
        Total number of zeros in the gradients, as a Python float.
    """
    # Function-scope import keeps this block self-contained.
    from megatron import get_args

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    # Start from a tensor (not a Python float) so that the all-reduce and
    # .item() below also work when no gradient qualifies on this rank.
    total_num_zeros = torch.cuda.FloatTensor([0.0])
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across ranks. With the distributed optimizer the reduction spans
    # the default (world) group (group=None); otherwise only the
    # model-parallel group.
    if get_args().use_distributed_optimizer:
        reduce_group = None
    else:
        reduce_group = mpu.get_model_parallel_group()
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=reduce_group)

    total_num_zeros = total_num_zeros.item()

    return total_num_zeros