# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
import torch.distributed as dist
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch


def pretrain_ict_model_provider():
    """Build the ICT biencoder (both the query encoder and the block encoder)."""
    return general_ict_model_provider(False, False)


def get_group_world_size_rank():
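    """Return the data-parallel group, this rank's index in it, and the
    group size."""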

    group = mpu.get_data_parallel_group()
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    return group, rank, world_size


class AllgatherFromDataParallelRegion(torch.autograd.Function):
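    """All-gather 2-D tensors across the data-parallel group, with autograd.

    The forward pass concatenates every rank's input along dim 0; the
    backward pass returns only this rank's slice of the gradient, which is
    the adjoint of the gather.
    """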

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()

        return output


    @staticmethod
    def backward(ctx, grad_output):
        group, rank, world_size = get_group_world_size_rank()

        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # The gradient of an all-gather is a split: keep only the chunk
        # that corresponds to this rank's forward input.
        output = output_list[rank].contiguous()
        return output


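# Usage sketch for AllgatherFromDataParallelRegion (a minimal illustration,
# assuming a data-parallel world size of 4 and per-rank logits of shape
# [b, h]):
#
#   all_logits = AllgatherFromDataParallelRegion.apply(logits)
#   # all_logits.shape == (4 * b, h); the gradient flowing back into
#   # all_logits is split so each rank keeps only its own [b, h] chunk.
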
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_pad_mask, \
    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask,
                                       block_tokens, block_pad_mask)

    micro_batch_size = query_logits.shape[0]
    # Recall we assert that tensor_model_parallel_size == 1, so the world
    # size equals the data-parallel size.
    global_batch_size = dist.get_world_size() * micro_batch_size

    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

    # Scores are inner products between query and block embeddings; the
    # correct block for query i is block i, i.e. the matrix diagonal.
    retrieval_scores = all_query_logits.float().matmul(
        torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(
        softmaxed, k=softmaxed.shape[1], sorted=True)

    def topk_accuracy(k):
        """Fraction of queries whose true block is among the top-k scores."""
        return torch.cuda.FloatTensor(
            [sum([int(i in sorted_indices[i, :k])
                  for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
    retrieval_loss = torch.nn.CrossEntropyLoss()(
        retrieval_scores, torch.arange(global_batch_size).long().cuda())

    retrieval_loss = retrieval_loss.float()
    averaged_losses = average_losses_across_data_parallel_group(
        [retrieval_loss, *topk_accs])

    # Create stats_dict with retrieval loss and all specified top-k accuracies.
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in
                     zip(args.report_topk_accuracies, averaged_losses[1:])}
    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
    return retrieval_loss, stats_dict


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


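# A hypothetical launch command (model sizes and paths below are placeholder
# assumptions; the flags come from Megatron's standard argument parser):
#
#   python pretrain_ict.py \
#       --num-layers 12 --hidden-size 768 --num-attention-heads 12 \
#       --seq-length 256 --max-position-embeddings 512 \
#       --data-path <ict-data-prefix> --vocab-file <bert-vocab-file> \
#       --report-topk-accuracies 1 5 10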
if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})