"src/targets/vscode:/vscode.git/clone" did not exist on "7e7a8a6120e8ccb4088877d0ead4486e5a3a56d5"
pretrain_ict.py 5.01 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
import torch.distributed as dist
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain
from megatron.utils import reduce_losses
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch


def pretrain_ict_model_provider():
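    """Build the full ICT bi-encoder (query tower and block tower).

    Assuming general_ict_model_provider's two flags are only_query_model
    and only_block_model, passing False for both builds both towers."""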
    return general_ict_model_provider(False, False)


def get_group_world_size_rank():
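    """Return the data-parallel process group plus this rank's index and
    the total number of ranks in that group."""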

    group = mpu.get_data_parallel_group()
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    return group, rank, world_size


class AllgatherFromDataParallelRegion(torch.autograd.Function):
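    """Autograd-aware all_gather over the data-parallel group.

    forward: gathers a 2-D tensor from every rank and concatenates along
    dim 0, so each rank ends up with the full global batch of embeddings.
    backward: splits the incoming gradient along dim 0 and returns only
    the chunk belonging to this rank (the gradient of an all_gather).
    """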

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()

        return output


    @staticmethod
    def backward(ctx, grad_output):
        group, rank, world_size = get_group_world_size_rank()

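        # The gradient spans the concatenated global batch, so it must
        # split evenly across the data-parallel world size.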
        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # get chunk from this rank
        output = output_list[rank].contiguous()
        return output


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    query_tokens, query_pad_mask, \
    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask,
                                       block_tokens, block_pad_mask)
    local_batch_size = query_logits.shape[0]
    # Recall we assert that model_parallel_size == 1.
    global_batch_size = dist.get_world_size() * local_batch_size

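    # Gather query and block embeddings from every data-parallel rank so
    # that all blocks in the global batch can serve as retrieval candidates.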
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(
        torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

    def topk_accuracy(k):
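        """Fraction of queries whose true block (block i for query i)
        appears among the k highest-scoring candidates."""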
        return torch.cuda.FloatTensor(
            [sum([int(i in sorted_indices[i, :k])
                  for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
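    # The correct block for query i sits at index i after the all-gather,
    # so the cross-entropy targets are simply arange(global_batch_size).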
    retrieval_loss = torch.nn.CrossEntropyLoss()(
        retrieval_scores, torch.arange(global_batch_size).long().cuda())

    reduced_losses = reduce_losses([retrieval_loss, *topk_accs])

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in
                     zip(args.report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)

    return retrieval_loss, stats_dict


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

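    # dataset_type='ict' makes the builder return ICT-style datasets whose
    # samples are (query sentence, evidence block) pairs rather than
    # standard masked-LM examples.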
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider,
             pretrain_ict_model_provider,
             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})