# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""
import math
import sys

import torch
import torch.distributed as dist
import torch.nn.functional as F

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def pretrain_ict_model_provider():
    """Build the biencoder model used for Inverse Cloze Task pretraining.

    Both the query and the context towers are requested (neither
    ``only_context_model`` nor ``only_query_model`` is set).
    ``args.shared_query_context_model`` is forwarded unchanged. By its name it
    presumably selects whether the two towers share parameters; confirm in
    ``megatron.model.biencoder_model``.

    Returns:
        The model object produced by ``biencoder_model_provider``.
    """
    args = get_args()
    model = biencoder_model_provider(
        only_context_model=False,
        only_query_model=False,
        shared_query_context_model=args.shared_query_context_model)
    return model


def get_group_world_size_rank():
    """Look up the data-parallel process group and this process's place in it.

    Note that, despite the function name, the tuple is ordered
    ``(group, rank, world_size)``.

    Returns:
        tuple: ``(group, rank, world_size)`` where ``rank`` and ``world_size``
        are both queried relative to the data-parallel group.
    """
    dp_group = mpu.get_data_parallel_group()
    return (dp_group,
            torch.distributed.get_rank(group=dp_group),
            torch.distributed.get_world_size(group=dp_group))


class AllgatherFromDataParallelRegion(torch.autograd.Function):
    """All-gather a 2-D tensor along dim 0 across the data-parallel group.

    Forward concatenates every rank's ``input_`` in rank order, giving a
    tensor with ``world_size * input_.shape[0]`` rows. Backward returns only
    the slice of the incoming gradient that lines up with this rank's own
    rows. Gradients for other ranks' rows are not exchanged here.
    NOTE(review): ``forward_step`` in this file multiplies the loss by the
    data-parallel world size, presumably to compensate for this together
    with DDP gradient averaging. Confirm before changing either side.
    """

    @staticmethod
    def forward(ctx, input_):
        """Gather ``input_`` from all data-parallel ranks and concatenate on dim 0."""
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        # Every slot is allocated with input_'s shape; this rank's slot is
        # pre-filled with input_ itself and all_gather fills the rest.
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the chunk of ``grad_output`` corresponding to this rank's rows."""
        group, rank, world_size = get_group_world_size_rank()

        # Forward stacked world_size equally-shaped chunks, so the gradient
        # must split evenly into world_size pieces along dim 0.
        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # get chunk from this rank
        output = output_list[rank].contiguous()
        return output

def forward_step(data_iterator, model, input_tensor):
    """Forward step for ICT pretraining.

    The step proceeds as follows:

    1. Run the biencoder on one micro-batch.
    2. All-gather the query and context embeddings across the data-parallel
       group.
    3. Score every gathered query against every gathered context.

    Query ``i`` is labelled with context ``i``. Every other context in the
    gathered batch serves as an in-batch negative.

    Args:
        data_iterator: iterator consumed by ``get_ict_batch``.
        model: callable returning ``(query_logits, context_logits)``.
        input_tensor: not used by this function.

    Returns:
        tuple: ``(loss, stats_dict)``.
            ``loss`` is the mean NLL retrieval loss multiplied by the
            data-parallel world size.
            ``stats_dict`` maps ``'loss'`` and ``'top{k}_acc'`` (a
            percentage, one entry per k in ``args.report_topk_accuracies``)
            to the values returned by
            ``average_losses_across_data_parallel_group``.
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types: all-zero token-type ids.
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    query_logits, context_logits = model(query_tokens, query_mask,
                                         query_types, context_tokens,
                                         context_mask, context_types)

    # Gather embeddings from every data-parallel rank so each query is scored
    # against the contexts of the whole global batch.
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
    # Take the global batch size from the gathered tensor rather than from
    # dist.get_world_size(). The gather spans only the data-parallel group,
    # which is smaller than the global world whenever tensor or pipeline
    # parallelism is enabled. Using the global world there would make
    # `labels` disagree with the score matrix.
    global_batch_size = all_query_logits.shape[0]

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                                    torch.transpose(all_context_logits, 0, 1))
    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    # Full descending sort of every row. Column j of sorted_indices is the
    # index of the j-th best-scoring context for that query.
    _, sorted_indices = torch.topk(softmax_scores,
                                   k=softmax_scores.shape[1], sorted=True)

    # The positive context for query i is context i.
    labels = torch.arange(global_batch_size, device=softmax_scores.device)

    def topk_accuracy(k):
        # Fraction of queries whose positive context is among their top-k.
        # Computed on-device in one shot instead of a per-row Python loop
        # (which forced a host sync per row). Shape [1], as before.
        hits = (sorted_indices[:, :k] == labels.unsqueeze(1)).any(dim=1)
        return hits.float().mean().view(1)

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]

    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in
                     zip(args.report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets.

    All dataset options are read from the global args and forwarded to
    ``build_train_valid_test_datasets`` with ``dataset_type='ict'``.

    Args:
        train_val_test_num_samples: forwarded unchanged as
            ``train_valid_test_num_samples``.

    Returns:
        tuple: ``(train_ds, valid_ds, test_ds)`` as returned by
        ``build_train_valid_test_datasets``.
    """
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        # mmap warmup is skipped unless explicitly requested.
        skip_warmup=(not args.mmap_warmup),
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    # Script entry point: hand the dataset provider, model provider and
    # forward step to Megatron's `pretrain`, passing a BERT lowercase
    # WordPiece tokenizer as an argument default.
    pretrain(train_valid_test_datasets_provider,
             pretrain_ict_model_provider,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})