# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
import torch.distributed as dist
import torch.nn.functional as F

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch


def pretrain_ict_model_provider():
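    """Build the ICT model: a BERT query encoder plus a BERT block (context) encoder."""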
    args = get_args()
    assert args.pipeline_model_parallel_size == 1, 'pipeline_model_parallel_size must be 1!'
    return general_ict_model_provider(False, False)


def get_group_world_size_rank():
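    """Return the data-parallel group along with this rank's index within it and its size."""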

    group = mpu.get_data_parallel_group()
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    return group, rank, world_size


class AllgatherFromDataParallelRegion(torch.autograd.Function):
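    """All-gather 2D tensors across the data-parallel group in the forward pass,
    and hand back only this rank's slice of the gradient in the backward pass."""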

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        torch.distributed.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=0).contiguous()

        return output


    @staticmethod
    def backward(ctx, grad_output):
        group, rank, world_size = get_group_world_size_rank()

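        # grad_output covers the gathered global batch: world_size equal chunks
        # stacked along dim 0, one per data-parallel rank.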
        assert grad_output.shape[0] % world_size == 0
        dim_size = grad_output.shape[0] // world_size
        output_list = torch.split(grad_output, dim_size, dim=0)

        # get chunk from this rank
        output = output_list[rank].contiguous()
        return output


def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    query_tokens, query_pad_mask, \
    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
    local_batch_size = query_logits.shape[0]
    # Recall this assumes no model parallelism, so the total world size equals
    # the data-parallel world size.
    global_batch_size = dist.get_world_size() * local_batch_size

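    # Gather the query and block embeddings from every data-parallel rank so that
    # each rank can score the full global batch of query/block pairs.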
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
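    # Rank all blocks for each query by retrieval probability (used only for the
    # top-k accuracy statistics below).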
    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

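    # Fraction of queries whose matching block appears among their top-k retrievals.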
    def topk_accuracy(k):
        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
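    # In-batch cross entropy: the correct block for the i-th query of the gathered
    # batch is the i-th block, so the targets are simply the row indices.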
    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
    retrieval_loss = retrieval_loss.float()
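    # Average the loss and accuracies across the data-parallel group for logging.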
    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)

    return retrieval_loss, stats_dict


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

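    # dataset_type='ict' makes the BERT dataset builder produce inverse-cloze-task
    # query/block pairs instead of standard masked-LM samples.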
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='ict')
    print_rank_0("> finished creating BERT ICT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})