# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""

import sys

import torch

from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Optimizer


def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

    return averaged_losses
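
# A minimal usage sketch (an illustration, not part of the original file): in a
# typical Megatron loss function, per-iteration scalar losses are averaged over
# the data-parallel group before being logged, along the lines of
#
#     lm_loss = losses.float().mean()
#     averaged_loss = average_losses_across_data_parallel_group([lm_loss])
#     # averaged_loss[0] holds lm_loss averaged across data-parallel ranks.
#
# Names such as `losses` and `lm_loss` here are assumptions for illustration.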


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
    optimizer_ = optimizer
    if isinstance(optimizer, FP16_Optimizer):
        optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.tensor_model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)
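
# Illustrative call site (an assumption, not part of this file): this debug
# helper can be invoked periodically to spot diverging parameters, e.g.
#
#     if iteration % args.log_interval == 0:
#         print_params_min_max_norm(optimizer, iteration)
#
# An FP16_Optimizer wrapper is unwrapped first, so it is the inner optimizer's
# parameter groups that get inspected.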


def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)
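
# Usage sketch (assumed, not from the original file): in the training loop this
# check is typically gated on an interval, e.g.
#
#     if args.adlr_autoresume and \
#        iteration % args.adlr_autoresume_interval == 0:
#         check_adlr_autoresume_termination(iteration, model, optimizer,
#                                           lr_scheduler)
#
# The flag names above are assumptions for illustration.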


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if we are going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
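
# Small self-contained sketch of the behaviour above (an illustration, not part
# of the original file; `0` plays the role of the EOD token):
#
#     data = torch.tensor([[5, 7, 0, 3, 4]])
#     attn, loss_mask, pos = get_ltor_masks_and_position_ids(
#         data, eod_token=0, reset_position_ids=True,
#         reset_attention_mask=True, eod_mask_loss=True)
#     # pos[0] -> [0, 1, 2, 0, 1]: positions restart after the EOD token.
#     # loss_mask[0] -> [1., 1., 0., 1., 1.]: the EOD position is dropped from
#     # the loss, and attn blocks attention across the EOD boundary.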

def params_grad_norm(model):
    """Compute the global L2 norm of all parameter gradients across ranks."""
    print_rank_0("params_grad_norm")
    norm2 = torch.cuda.FloatTensor([0.0])
    for param in model.parameters():
        if param.grad is None:
            continue
        norm = torch.norm(param.grad.data.float(), 2)
        norm2 += norm * norm
    torch.distributed.all_reduce(norm2)
    norm = norm2 ** 0.5
    return norm.item()


def params_global_norm(model):
    """Compute the global L2 norm of all parameters across ranks."""
    print_rank_0("params_global_norm")
    norm2 = torch.cuda.FloatTensor([0.0])
    for param in model.parameters():
        norm = torch.norm(param.data.float(), 2)
        norm2 += norm * norm
    torch.distributed.all_reduce(norm2)
    norm = norm2 ** 0.5
    return norm.item()
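
# Both norm helpers above compute a single global L2 norm by all-reducing the
# sum of squared per-parameter norms and taking the square root, i.e.
# sqrt(sum over ranks and local params of ||p||^2). The all-reduce runs over
# the default process group, so parameters replicated across data-parallel
# ranks are counted once per rank holding them; treat these as rough debugging
# aids rather than exact norms.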

def print_model(model):
    """Print the first trainable parameter (debug helper; returns after one)."""
    print_rank_0("print-model")
    for name, param in model.named_parameters():
        if param.requires_grad:
            #print("{} {}".format(name, param.data), flush=True)
            print_rank_0("{} {}".format(name, param.data))
            return

def print_grads(model):
    """Print all parameter gradients (debug helper)."""
    print_rank_0("print-grads")
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        print_rank_0("{} {}".format(name, param.grad))