# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""General utilities."""

import sys

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_args
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def unwrap_model(model, module_instances=(torchDDP,)):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model

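# Usage sketch (illustrative only): strip DDP-style wrappers to reach the
# underlying module, e.g. before saving a checkpoint. `LocalDDP` and
# `Float16Module` are the other Megatron wrapper classes typically passed in;
# they are assumed to be imported by the caller.
#
#   unwrapped = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))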

def calc_params_l2_norm(model):
    """Calculate the l2 norm of model parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5

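# Usage sketch (illustrative only; assumes torch.distributed and Megatron's
# model-parallel groups are already initialized):
#
#   params_norm = calc_params_l2_norm(model)
#   print_rank_last('params l2 norm: {:.4f}'.format(params_norm))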

def average_losses_across_data_parallel_group(losses):
    """Average a list of loss scalars across the data-parallel group."""
    averaged_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

    return averaged_losses

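# Usage sketch (illustrative only): average a per-micro-batch scalar loss so
# that every data-parallel rank reports the same value. `lm_loss` is a
# hypothetical scalar tensor produced by the caller's forward step.
#
#   averaged_loss = average_losses_across_data_parallel_group([lm_loss])
#   loss_dict = {'lm loss': averaged_loss[0]}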

def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
              flush=True)

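# Usage sketch (illustrative only): call around the first training iterations
# to log allocator statistics; output is printed only on ranks whose
# data-parallel rank is 0.
#
#   report_memory('after iteration {}'.format(iteration))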

def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.tensor_model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)

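# Debugging sketch (illustrative only; assumes `optimizer` is a Megatron
# optimizer wrapper that exposes the underlying torch optimizer as
# `.optimizer`, which is what this function requires):
#
#   print_params_min_max_norm(optimizer, iteration)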

def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, opt_param_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)

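# Usage sketch (illustrative only): typically invoked periodically from the
# training loop. The flag names below mirror Megatron's argument parser and
# should be treated as assumptions here.
#
#   if args.adlr_autoresume and \
#           iteration % args.adlr_autoresume_interval == 0:
#       check_adlr_autoresume_termination(iteration, model, optimizer,
#                                         opt_param_scheduler)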

def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position ids for a left-to-right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where the EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if we are going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Prevent attention across the EOD boundary.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids

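# Usage sketch (illustrative only; `tokens` is a [micro_batch_size, seq_length]
# integer tensor and `tokenizer.eod` is the end-of-document token id, both
# assumed to come from the caller's data pipeline):
#
#   attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
#       tokens,
#       tokenizer.eod,
#       args.reset_position_ids,
#       args.reset_attention_mask,
#       args.eod_mask_loss)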

def print_rank_0(message):
    """If distributed is initialized, print only on rank 0."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)

def is_last_rank():
    return torch.distributed.get_rank() == (
        torch.distributed.get_world_size() - 1)

def print_rank_last(message):
    """If distributed is initialized, print only on last rank."""
    if torch.distributed.is_initialized():
        if is_last_rank():
            print(message, flush=True)
    else:
        print(message, flush=True)
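
# Usage sketch (illustrative only): these helpers avoid duplicated log lines
# when every rank executes the same code path.
#
#   print_rank_0('> building train, validation, and test datasets ...')
#   print_rank_last('validation loss: {:.6E}'.format(loss))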