# pretrain_ict.py (7.9 KB)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

import torch
Neel Kant's avatar
Neel Kant committed
19
import torch.distributed as dist
20
21
import torch.nn.functional as F

Neel Kant's avatar
Neel Kant committed
22
23
from megatron import get_args
from megatron import print_rank_0
24
from megatron import get_timers
25
from megatron import mpu
26
from megatron.data.dataset_utils import build_train_valid_test_datasets
27
from megatron.model import ICTBertModel
28
from megatron.training import pretrain
29
30
from megatron.utils import reduce_losses

Neel Kant's avatar
Neel Kant committed
31
# NOTE(review): module-level counter initialized to 0; nothing in this file
# reads or updates it -- confirm whether it is still needed.
num_batches = 0
32

Neel Kant's avatar
Neel Kant committed
33

34
def general_model_provider(only_query_model=False, only_block_model=False):
    """Construct an ICTBertModel from the global Megatron arguments.

    Args:
        only_query_model: forwarded unchanged to ``ICTBertModel``.
        only_block_model: forwarded unchanged to ``ICTBertModel``.

    Returns:
        The newly constructed ``ICTBertModel``.

    Raises:
        AssertionError: if ``--ict-head-size`` was not given, or if the
            model-parallel size is anything other than 1.
    """
    cfg = get_args()

    head_size = cfg.ict_head_size
    assert head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert cfg.model_parallel_size == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')

    # Two token types are kept because the language model used for
    # initialization has 2 token types; matching it is the simplest choice.
    return ICTBertModel(ict_head_size=head_size,
                        num_tokentypes=2,
                        parallel_output=True,
                        only_query_model=only_query_model,
                        only_block_model=only_block_model)


56
57
58
59
def model_provider():
    """Build the ICT model with neither ``only_query_model`` nor
    ``only_block_model`` enabled; this is the hook passed to ``pretrain``."""
    return general_model_provider(only_query_model=False,
                                  only_block_model=False)


mohammad's avatar
mohammad committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126

def get_group_world_size_rank():
    """Return ``(group, rank, world_size)`` for the data-parallel group.

    ``group`` comes from ``mpu.get_data_parallel_group()``; ``rank`` is this
    process's rank inside that group and ``world_size`` is the group's size.
    """
    dp_group = mpu.get_data_parallel_group()
    return (dp_group,
            torch.distributed.get_rank(group=dp_group),
            torch.distributed.get_world_size(group=dp_group))


def get_rank_chunk_along_first_dim(tensor):
    """Return this data-parallel rank's slice of ``tensor`` along dim 0.

    Dim 0 is cut into ``world_size`` equal pieces (it must divide evenly,
    which is asserted). Piece number ``rank`` is returned as a contiguous
    tensor.
    """
    _, my_rank, n_ranks = get_group_world_size_rank()

    assert tensor.shape[0] % n_ranks == 0
    pieces = torch.split(tensor, tensor.shape[0] // n_ranks, dim=0)
    return pieces[my_rank].contiguous()


class AllgatherFromDataParallelRegion(torch.autograd.Function):
    """All-gather a 2-D tensor across the data-parallel group.

    Forward: every rank's input is collected with
    ``torch.distributed.all_gather`` and concatenated along dim 0 in rank
    order.
    Backward: returns the row-slice of ``grad_output`` that belongs to this
    rank's input.
    """

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        # This rank's slot reuses the local input; the other slots are fresh
        # buffers that all_gather fills in.
        gathered = [input_ if slot == rank else torch.empty_like(input_)
                    for slot in range(world_size)]
        torch.distributed.all_gather(gathered, input_, group=group)

        return torch.cat(gathered, dim=0).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE(review): only the local slice of grad_output is returned; it is
        # not all-reduced across ranks first -- confirm the resulting gradient
        # scaling is intended.
        return get_rank_chunk_along_first_dim(grad_output)


class AllReduceFromDataParallelRegion(torch.autograd.Function):
    """Gather a 2-D tensor across the data-parallel group using an all-reduce.

    Forward: each rank places its input in its own row-block of an otherwise
    zero buffer of ``world_size`` row-blocks. That buffer is summed across
    the group with ``torch.distributed.all_reduce``, so the result is every
    rank's input concatenated along dim 0 in rank order. This assumes every
    rank passes an input of the same shape.
    Backward: returns the row-slice of ``grad_output`` that belongs to this
    rank's input.
    """

    @staticmethod
    def forward(ctx, input_):
        assert input_.dim() == 2
        group, rank, world_size = get_group_world_size_rank()

        # Zero padding (torch.zeros_like; torch.zero_like does not exist)
        # ensures the cross-rank sum leaves exactly one rank's data in each
        # row-block.
        tensor_list = [torch.zeros_like(input_) for _ in range(world_size)]
        tensor_list[rank] = input_
        output = torch.cat(tensor_list, dim=0).contiguous()
        torch.distributed.all_reduce(output, group=group)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        return get_rank_chunk_along_first_dim(grad_output)


127
def get_batch(data_iterator):
    """Fetch the next ICT batch and distribute it with ``mpu.broadcast_data``.

    Args:
        data_iterator: iterator yielding dict-like batches, or None. When it
            is None, ``None`` is passed to ``mpu.broadcast_data`` as the data.

    Returns:
        ``(query_tokens, query_pad_mask, block_tokens, block_pad_mask,
        block_indices)``, each converted with ``.long()``. ``block_indices``
        is read from the ``'block_data'`` key.
    """
    # Tuple order below mirrors this key order.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']

    raw_batch = None if data_iterator is None else next(data_iterator)
    broadcast = mpu.broadcast_data(keys, raw_batch, torch.int64)

    return tuple(broadcast[key].long() for key in keys)
149
150


151
def forward_step(data_iterator, model):
    """Run one ICT forward pass and compute the in-batch retrieval loss.

    Query and block logits are combined across data-parallel ranks, so every
    query is scored against every block in the global batch. Query ``i``'s
    target is block ``i``.

    Args:
        data_iterator: passed to ``get_batch`` (may be None).
        model: called as ``model(query_tokens, query_pad_mask, block_tokens,
            block_pad_mask)`` and expected to return
            ``(query_logits, block_logits)``.

    Returns:
        ``(retrieval_loss, stats_dict)``. ``stats_dict`` holds the
        ``reduce_losses``-reduced retrieval loss under ``'retrieval_loss'``
        and one ``'top{k}_acc'`` entry per value in
        ``args.report_topk_accuracies``.
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch generator').start()
    query_tokens, query_pad_mask, \
        block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
    timers('batch generator').stop()

    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask,
                                       block_tokens, block_pad_mask)

    # Strategy used to combine logits across ranks.
    IMPLEMENTATION = 'original'
    all_query_logits, all_block_logits = _gather_retrieval_logits(
        query_logits, block_logits, IMPLEMENTATION, args)

    # Derived from the gathered tensor so that every strategy defines it
    # (previously only the 'original' branch set global_batch_size).
    global_batch_size = all_query_logits.shape[0]

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(
        torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
    _, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1],
                                   sorted=True)

    # Query i's correct block is block i of the global batch.
    targets = torch.arange(global_batch_size, device=retrieval_scores.device)

    def topk_accuracy(k):
        # Fraction of queries whose correct block is among their top-k
        # blocks. Computed on-device rather than with a per-row Python loop.
        hits = (sorted_indices[:, :k] == targets.unsqueeze(1)).any(dim=1)
        return hits.float().mean().view(1)

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, targets)

    reduced_losses = reduce_losses([retrieval_loss, *topk_accs])

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in
                     zip(args.report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)

    return retrieval_loss, stats_dict


def _gather_retrieval_logits(query_logits, block_logits, implementation, args):
    """Combine this rank's query/block logits into global-batch tensors.

    Args:
        query_logits: this rank's 2-D query logits (rows are examples).
        block_logits: this rank's 2-D block logits, same shape as
            ``query_logits``.
        implementation: one of
            - ``'original'``: zero buffer plus ``dist.all_reduce`` on the
              default group;
            - ``'allgather'``: ``AllgatherFromDataParallelRegion``;
            - ``'allreduce'``: ``AllReduceFromDataParallelRegion``.
        args: Megatron args. ``'original'`` reads ``args.rank`` and
            ``args.model_parallel_size``.

    Returns:
        ``(all_query_logits, all_block_logits)``. Each stacks every rank's
        rows along dim 0.

    Raises:
        ValueError: if ``implementation`` is not a recognized name.
    """
    if implementation == 'original':
        data_parallel_size = dist.get_world_size() / args.model_parallel_size
        batch_size = query_logits.shape[0]
        global_batch_size = int(batch_size * data_parallel_size)

        all_logits_shape = (global_batch_size, int(query_logits.shape[1]))
        all_query_logits = torch.zeros(all_logits_shape,
                                       dtype=query_logits.dtype,
                                       device=query_logits.device)
        all_block_logits = all_query_logits.clone()

        # record this process's data. NOTE: indexes by the global args.rank,
        # which equals the data-parallel rank only when
        # model_parallel_size == 1.
        start = args.rank * batch_size
        end = (args.rank + 1) * batch_size
        all_query_logits[start:end] = query_logits
        all_block_logits[start:end] = block_logits

        # merge data from all processes
        dist.all_reduce(all_query_logits)
        dist.all_reduce(all_block_logits)
        return all_query_logits, all_block_logits

    if implementation == 'allgather':
        return (AllgatherFromDataParallelRegion.apply(query_logits),
                AllgatherFromDataParallelRegion.apply(block_logits))

    if implementation == 'allreduce':
        return (AllReduceFromDataParallelRegion.apply(query_logits),
                AllReduceFromDataParallelRegion.apply(block_logits))

    raise ValueError(
        'unknown logit gather implementation: {!r}'.format(implementation))
212
213


Neel Kant's avatar
Neel Kant committed
214
215
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, validation and test datasets for BERT ICT.

    Args:
        train_val_test_num_samples: forwarded to
            ``build_train_valid_test_datasets`` as
            ``train_valid_test_num_samples``.

    Returns:
        ``(train_ds, valid_ds, test_ds)`` from
        ``build_train_valid_test_datasets`` called with
        ``dataset_type='ict'``.
    """
    args = get_args()
    print_rank_0('> building train, validation, and test datasets '
                 'for BERT ICT...')

    # Every dataset setting comes from the global args, except the sample
    # counts (a parameter) and the fixed 'ict' dataset type.
    build_kwargs = dict(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=not args.mmap_warmup,
        dataset_type='ict',
    )
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        **build_kwargs)

    print_rank_0("> finished creating BERT ICT datasets ...")
    return train_ds, valid_ds, test_ds
234
235
236


if __name__ == "__main__":
    # Launch Megatron pretraining with the ICT dataset, model and step hooks.
    # The tokenizer type defaults to BertWordPieceLowerCase.
    ict_args_defaults = {'tokenizer_type': 'BertWordPieceLowerCase'}
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             forward_step,
             args_defaults=ict_args_defaults)