# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""

import sys

import torch

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.checkpointing import save_checkpoint
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters """
    # Remove duplicate params.
    params_data = []
    for param in model.parameters():
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if is_not_shared and is_not_tp_duplicate:
            params_data.append(param.data)
    # Calculate norm
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5
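

# Illustrative sketch (added for clarity, not part of the original Megatron-LM file):
# the same norm computed with plain PyTorch instead of apex's fused
# multi_tensor_l2norm kernel. The helper name is hypothetical; it assumes the same
# model-parallel state as calc_params_l2_norm above and performs no cross-GPU
# reduction, so on a model-parallel run it only covers this rank's parameter shards.
def _calc_params_l2_norm_reference(model):
    squared_sum = 0.0
    for param in model.parameters():
        if param_is_not_shared(param) and \
                param_is_not_tensor_parallel_duplicate(param):
            # Accumulate the squared l2 norm of each unique parameter tensor.
            squared_sum += torch.sum(param.data.float() ** 2).item()
    return squared_sum ** 0.5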


def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

    return averaged_losses
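

# Usage sketch (illustrative): called from a training step with per-micro-batch scalar
# losses; assumes torch.distributed and the Megatron data-parallel group are initialized.
#
#     lm_loss, sop_loss = ...                       # scalar loss tensors on this rank
#     averaged = average_losses_across_data_parallel_group([lm_loss, sop_loss])
#     # averaged[0], averaged[1] are the losses averaged over all data-parallel ranks.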


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
              flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.tensor_model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)
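

# Usage sketch (illustrative): this helper expects the Megatron optimizer wrapper, whose
# underlying torch optimizer is exposed as `optimizer.optimizer` (see above), and prints
# one formatted line per parameter on every rank, e.g.
#
#     print_params_min_max_norm(optimizer, iteration)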


def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token occurs.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if we are going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Prevent attention across the EOD boundary.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
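

# Behaviour sketch on a toy batch (illustrative, not part of the original file),
# assuming the EOD token id is 0:
#
#     data = torch.tensor([[5, 3, 0, 7, 2]])
#     attn_mask, loss_mask, pos_ids = get_ltor_masks_and_position_ids(
#         data, eod_token=0, reset_position_ids=True,
#         reset_attention_mask=True, eod_mask_loss=True)
#     # loss_mask -> [[1., 1., 0., 1., 1.]]  (the EOD position is excluded from the loss)
#     # pos_ids   -> [[0, 1, 2, 0, 1]]       (position ids restart after the EOD token)
#     # attn_mask is boolean with True marking positions that must NOT be attended to;
#     # on top of the causal mask, tokens after the EOD cannot attend to the EOD token
#     # or anything before it.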