utils.py 7.14 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""General utilities."""
Raul Puri's avatar
Raul Puri committed
17

18
import sys
19

20
import torch
21
from torch.nn.parallel import DistributedDataParallel as torchDDP
22

mohammad's avatar
mohammad committed
23
24
25
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

Neel Kant's avatar
Neel Kant committed
26
27
from megatron import get_args
from megatron import print_rank_0
28
from megatron import get_adlr_autoresume
29
from megatron import mpu
mohammad's avatar
mohammad committed
30
31
32
33
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def unwrap_model(model, module_instances=(torchDDP)):
    """Strip wrapper modules (by default DistributedDataParallel) off a model.

    The ``.module`` attribute is followed repeatedly for as long as the
    current object is an instance of ``module_instances``.

    Args:
        model: a single module, or a list of modules.
        module_instances: class (or tuple of classes) to peel off.

    Returns:
        The unwrapped module when ``model`` is a single module; otherwise a
        new list holding the unwrapped modules in the original order.
    """
    def _peel(wrapped):
        # Descend through nested wrappers until a non-wrapper is reached.
        while isinstance(wrapped, module_instances):
            wrapped = wrapped.module
        return wrapped

    if isinstance(model, list):
        return [_peel(chunk) for chunk in model]
    return _peel(model)


mohammad's avatar
mohammad committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def calc_params_l2_norm(model):
    """Calculate the L2 norm of the model parameters.

    Parameters rejected by ``param_is_not_shared`` or
    ``param_is_not_tensor_parallel_duplicate`` are left out, so duplicates
    are not counted more than once. The local squared norm is then summed
    over the model-parallel group.

    Args:
        model: a module exposing ``parameters()``, or a list of such
            modules (e.g. model chunks). A single module is treated as a
            one-element list, so existing single-module callers are
            unaffected.

    Returns:
        Python float: square root of the all-reduced sum of squared norms.
    """
    # Accept a list of model chunks as well as a single module.
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_module in model:
        for param in model_module.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                params_data.append(param.data)
    # Calculate the norm of all collected tensors via apex's
    # multi_tensor_applier.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5
72

Mohammad's avatar
Mohammad committed
73

74
def average_losses_across_data_parallel_group(losses):
    """Average each loss over every rank in the data-parallel group.

    Args:
        losses: iterable of tensors, each of which must be viewable as a
            single element (``.view(1)``).

    Returns:
        1-D tensor with one entry per input loss. Each entry is the
        all-reduced (summed) value divided by the data-parallel world size.
    """
    # Detached copies, so the collective below does not touch the caller's
    # tensors or autograd graph.
    stacked = torch.cat(
        [loss_tensor.clone().detach().view(1) for loss_tensor in losses])
    dp_group = mpu.get_data_parallel_group()
    torch.distributed.all_reduce(stacked, group=dp_group)
    dp_size = torch.distributed.get_world_size(group=dp_group)
    return stacked / dp_size
Mohammad's avatar
Mohammad committed
84
85


86
87
88
89
90
91
92
93
def report_memory(name):
    """Print current and peak CUDA allocator statistics in megabytes.

    Only ranks whose data-parallel rank is 0 print. The line is prefixed
    with the global rank.

    Args:
        name: label placed at the start of the report.
    """
    bytes_per_mb = 1024.0 * 1024.0
    pieces = [
        name + ' memory (MB)',
        ' | allocated: {}'.format(
            torch.cuda.memory_allocated() / bytes_per_mb),
        ' | max allocated: {}'.format(
            torch.cuda.max_memory_allocated() / bytes_per_mb),
        ' | reserved: {}'.format(
            torch.cuda.memory_reserved() / bytes_per_mb),
        ' | max reserved: {}'.format(
            torch.cuda.max_memory_reserved() / bytes_per_mb),
    ]
    report = ''.join(pieces)
    if mpu.get_data_parallel_rank() != 0:
        return
    print("[Rank {}] {}".format(torch.distributed.get_rank(), report),
          flush=True)
101
102
103
104
105
106


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and L2 norm of every parameter held by the optimizer.

    The output starts with a header line. After that comes one line per
    parameter, containing: iteration, global rank, a 1-based running index
    over all param groups, the tensor-model-parallel flag (0/1), and then
    min, max, and norm in scientific notation.

    Args:
        optimizer: wrapper exposing the underlying torch optimizer as
            ``optimizer.optimizer``; its ``param_groups`` are walked.
        iteration: iteration number printed in the first column.
    """
    rank = torch.distributed.get_rank()
    # Collect pieces and join once, instead of growing a string with +=.
    pieces = ['iteration, rank, index, tensor-model-parallel, min, max, norm\n']
    index = 0
    for param_group in optimizer.optimizer.param_groups:
        for param in param_group['params']:
            index += 1
            data = param.data
            min_ = data.min()
            max_ = data.max()
            norm = torch.linalg.norm(data)
            # Guard against parameters that were never tagged with the
            # tensor-parallel attribute; treat them as not tensor-parallel
            # instead of raising AttributeError.
            is_tp = int(getattr(param, 'tensor_model_parallel', False))
            pieces.append('{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, is_tp))
            pieces.append('{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm))
    print(''.join(pieces), flush=True)


121
122
123
def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Exit the process if ADLR autoresume has requested termination.

    All ranks first meet at a barrier. If termination was requested:
    a checkpoint is saved (only when ``args.save`` is set), rank 0 calls
    ``autoresume.request_resume()``, and the process exits with status 0.
    Otherwise the function returns None.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import with megatron.checkpointing).
    from megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Barrier so every rank reaches the termination check together.
    torch.distributed.barrier()
    if not autoresume.termination_requested():
        return

    if args.save:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)
    print_rank_0(">>> autoresume termination request found!")
    if torch.distributed.get_rank() == 0:
        autoresume.request_resume()
    print_rank_0(">>> training terminated. Returning")
    sys.exit(0)


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
140
141
142
143
def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position ids for a left-to-right (causal) model.

    Args:
        data: 2-D token tensor of shape (micro_batch_size, seq_length).
        eod_token: token id that marks an end-of-document boundary.
        reset_position_ids: if True, position ids restart at 0 right after
            every EOD token in a row.
        reset_attention_mask: if True, a separate mask is built per sample,
            and tokens after an EOD cannot attend to tokens at or before it.
        eod_mask_loss: if True, the loss mask is zeroed at EOD positions.

    Returns:
        (attention_mask, loss_mask, position_ids):
        attention_mask -- bool tensor (mask_batch, 1, seq_length,
            seq_length), where True marks positions that are masked out.
            mask_batch is micro_batch_size if reset_attention_mask,
            otherwise 1.
        loss_mask -- float tensor shaped like ``data``.
        position_ids -- long tensor shaped like ``data``.
    """
    batch, length = data.size()
    device = data.device

    # Lower-triangular (causal) mask; 1 = may attend, 0 = may not.
    mask_batch = batch if reset_attention_mask else 1
    causal = torch.ones((mask_batch, length, length), device=device)
    attention_mask = torch.tril(causal).view(mask_batch, 1, length, length)

    loss_mask = torch.ones(data.size(), dtype=torch.float, device=device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    arange_ids = torch.arange(length, dtype=torch.long, device=device)
    position_ids = arange_ids.unsqueeze(0).expand_as(data)
    if reset_position_ids:
        # expand_as gives a view whose rows share storage; clone it so the
        # per-row edits below stay independent.
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        for row in range(batch):
            # Row ``row`` has not been edited yet, so it still holds
            # 0..length-1. Selecting it with the EOD mask therefore yields
            # the column indices of the EOD tokens.
            eod_cols = position_ids[row, data[row] == eod_token]
            if reset_position_ids:
                eod_cols = eod_cols.clone()

            doc_start = 0
            for col in eod_cols:
                after = col + 1
                if reset_attention_mask:
                    # Tokens after this EOD cannot see it or anything before it.
                    attention_mask[row, 0, after:, :after] = 0
                if reset_position_ids:
                    position_ids[row, after:] -= (after - doc_start)
                    doc_start = after

    # Flip to a boolean "blocked" mask: True where attention is disallowed.
    attention_mask = (attention_mask < 0.5)
    return attention_mask, loss_mask, position_ids
198
199