# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.

"""Pretrain BERT for Inverse Cloze Task"""

from functools import partial
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F

from megatron.training import get_args
from megatron.training import print_rank_0
from megatron.training import get_timers
from megatron.core import mpu
from megatron.core.enums import ModelType
from megatron.legacy.data.biencoder_dataset_utils import get_ict_batch
from megatron.legacy.data.dataset_utils import build_train_valid_test_datasets
from megatron.legacy.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.training.utils import average_losses_across_data_parallel_group


def pretrain_ict_model_provider(pre_process=True, post_process=True):
    """Build the biencoder (query and context) model for ICT pretraining."""
    args = get_args()

    model = biencoder_model_provider(
                only_context_model=False,
                only_query_model=False,
                biencoder_shared_query_context_model=\
                args.biencoder_shared_query_context_model,
                pre_process=pre_process, post_process=post_process)

    return model

def get_group_world_size_rank():
    """Return the data-parallel group, this rank's index within it, and its size."""

    group = mpu.get_data_parallel_group()
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    return group, rank, world_size


class AllgatherFromDataParallelRegion(torch.autograd.Function):
    """All-gather activations across the data-parallel group; in backward,
    return only the gradient chunk that belongs to this rank."""

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()

        return output


    @staticmethod
    def backward(ctx, grad_output):
        group, rank, world_size = get_group_world_size_rank()

        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # get chunk from this rank
        output = output_list[rank].contiguous()
        return output
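
# Illustrative note on the custom all-gather above (shapes here are assumptions,
# not taken from a real configuration): with a data-parallel world size of 4 and
# per-rank logits of shape [micro_batch_size, hidden_size], the forward pass
# returns a [4 * micro_batch_size, hidden_size] tensor, while the backward pass
# hands each rank only the [micro_batch_size, hidden_size] slice of the incoming
# gradient that corresponds to its own rows.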

def loss_func(output_tensor):
    """Compute the retrieval loss and top-k retrieval accuracies for ICT."""
    args = get_args()
    query_logits, context_logits = output_tensor

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                        torch.transpose(all_context_logits, 0, 1))
    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
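
    # The gathered logits cover the full global batch in rank order, so
    # retrieval_scores is a [global_batch_size, global_batch_size] matrix whose
    # (i, j) entry scores query i against context j. Query i's positive context
    # sits on the diagonal, which is why the labels below are simply
    # torch.arange(global_batch_size) (in-batch negatives).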

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                    k=softmax_scores.shape[1], sorted=True)

    def topk_accuracy(k):
        """Fraction of queries whose positive context is ranked in the top k."""
        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
            for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    query_tokens, query_mask, \
    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and context token type IDs (all zeros for the BERT-style biencoder).
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
                        context_mask, context_types)
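
    # The training loop applies the returned loss function to the model output,
    # so loss_func receives the (query_logits, context_logits) pair produced above.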

    return output_tensor, partial(loss_func)

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        binary_head=False,
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    print_rank_0("WARNING : This script is DEPRECATED. Will be removed in mcore release 0.9")
    pretrain(train_valid_test_datasets_provider,
             pretrain_ict_model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
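
# A rough launch sketch (illustrative only: hyperparameter values are placeholders,
# and the ICT/biencoder-specific flags are assumed to follow the hyphenated form of
# the `args.*` attributes used above, e.g. --retriever-score-scaling):
#
#   torchrun --nproc_per_node=8 pretrain_ict.py \
#       --num-layers 12 --hidden-size 768 --num-attention-heads 12 \
#       --seq-length 256 --max-position-embeddings 512 \
#       --micro-batch-size 32 \
#       --data-path <ICT_DATA_PREFIX> --vocab-file <BERT_VOCAB> \
#       --split 98,2,0 --lr 1e-4 --train-iters 100000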