# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""

import sys

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_args
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def unwrap_model(model, module_instances=(torchDDP,)):
    """Unwrap model(s) from the wrapper classes in `module_instances`
    (e.g. DDP). Accepts a single model or a list of models."""
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model
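
# Illustrative usage (a sketch, not part of the module; assumes
# torch.distributed and a CUDA device are already initialized, and the
# Linear layer is just a stand-in):
#
#   >>> net = torchDDP(torch.nn.Linear(4, 4).cuda())
#   >>> unwrap_model(net) is net.module
#   True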


def calc_params_l2_norm(model):
    """Calculate the l2 norm of model parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    # Upcast bf16 params so the norm is computed in fp32.
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5
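
# For reference (a sketch, not executed here): up to the model-parallel
# all-reduce, the fused kernel above computes the same value as this
# pure-PyTorch reduction over the `params_data` list built inside the
# function:
#
#   >>> local_sq = sum(p.float().pow(2).sum() for p in params_data)
#   >>> local_norm = local_sq.item() ** 0.5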


def average_losses_across_data_parallel_group(losses):
    """Average losses across the data-parallel group."""
    averaged_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

    return averaged_losses
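
# Typical call site in a training step (a sketch; `lm_loss` is a
# hypothetical loss tensor):
#
#   >>> averaged = average_losses_across_data_parallel_group([lm_loss])
#   >>> loss_dict = {'lm loss': averaged[0]}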


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
              flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.tensor_model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)


def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, opt_param_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position ids for left-to-right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary (True means the position is masked):
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
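
# Worked example (a sketch, not executed here): with eod_token=0 and
# data = torch.tensor([[5, 3, 0, 7, 2]]), passing reset_position_ids=True
# and reset_attention_mask=True yields position_ids [[0, 1, 2, 0, 1]] and a
# block-diagonal causal mask in which tokens 3..4 cannot attend to tokens
# 0..2; with eod_mask_loss=True, loss_mask is [[1., 1., 0., 1., 1.]].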


def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def is_last_rank():
    """Return True if called on the last rank of the distributed job."""
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)


def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized():
        if is_last_rank():
            print(message, flush=True)
    else:
        print(message, flush=True)