finetune.py 12.4 KB
Newer Older
mpatwary's avatar
mpatwary committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORQA finetuning/evaluation."""

from functools import partial
Mostofa Patwary's avatar
Mostofa Patwary committed
19
import sys
mpatwary's avatar
mpatwary committed
20
21
22
23
24

import math
import torch
import torch.nn.functional as F

Mostofa Patwary's avatar
Mostofa Patwary committed
25
26
27
from megatron import get_args, get_timers, get_tokenizer
from megatron import mpu, print_rank_0
from megatron.indexer import IndexBuilder
mpatwary's avatar
mpatwary committed
28
from megatron.model.biencoder_model import biencoder_model_provider
Mostofa Patwary's avatar
Mostofa Patwary committed
29
from megatron.utils import average_losses_across_data_parallel_group
mpatwary's avatar
mpatwary committed
30
31
32
33
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
Mostofa Patwary's avatar
Mostofa Patwary committed
34
from tasks.orqa.evaluate_utils import ORQAEvaluator
mpatwary's avatar
mpatwary committed
35

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def check_and_append_tensor_for_gather(group, rank, world_size, input_):
    """Zero-pad ``input_`` along dim 0 to the largest dim-0 size in ``group``.

    ``torch.distributed.all_gather`` needs identically shaped tensors on every
    rank. This helper all-gathers the local first-dimension length. If any
    rank holds more rows, it appends zero rows at the end of dim 0 until the
    local tensor matches the largest first dimension seen in ``group``.

    Args:
        group: process group used for the length all-gather.
        rank: this process's rank within ``group`` (kept for interface
            compatibility; all_gather fills every slot itself).
        world_size: number of ranks in ``group``.
        input_: tensor with at least one dimension; only dim 0 is padded,
            any remaining dimensions are left untouched.

    Returns:
        ``input_`` itself when it already has the largest first dimension,
        otherwise a new tensor padded with zeros at the end of dim 0.
    """
    current_length = input_.size(0)
    first_dim = torch.tensor([[current_length]],
                             device=torch.cuda.current_device())
    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
    torch.distributed.all_gather(input_list, first_dim, group=group)

    # Convert to a Python int once: F.pad takes integer padding amounts, and
    # a host-side int avoids repeated implicit device syncs in comparisons.
    max_length = int(torch.cat(input_list, dim=0).max().item())

    if max_length > current_length:
        # F.pad lists (left, right) pairs starting from the LAST dimension, so
        # the final entry is the right-hand padding of dim 0; every other
        # dimension receives zero padding.
        padding = (0,) * (2 * input_.dim() - 1) + \
            (max_length - current_length,)
        input_ = F.pad(input_, padding)

    return input_

Mostofa Patwary's avatar
Mostofa Patwary committed
74
def orqa(Dataset):
    """Finetune (and evaluate each epoch) a biencoder retriever.

    Builds the callbacks expected by ``tasks.finetune_utils.finetune`` as
    closures over ``Dataset`` and hands them to ``finetune``. The callbacks
    are the forward step, the loss, the dataset/model providers and the
    end-of-epoch metrics.

    Args:
        Dataset: dataset class, constructed here as
            ``Dataset(name, datapath, tokenizer, seq_length, evaluate=...)``.
    """

    def cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch. ``batch`` is either an iterator (next() succeeds) or
        # an already materialised batch (next() raises). Catch Exception, not
        # BaseException, so KeyboardInterrupt/SystemExit still propagate.
        timers('batch generator').start()
        try:
            batch_ = next(batch)
        except Exception:
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, \
        reference = process_batch(batch_)

        timers('batch generator').stop()

        if neg_context_tokens is not None:
            # Ranks may hold different numbers of negatives; pad them to a
            # common length so the loss function can all-gather the
            # resulting context logits.
            neg_context_tokens = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_tokens)
            neg_context_mask = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_mask)
            neg_context_types = check_and_append_tensor_for_gather(
                group, rank, world_size, neg_context_types)

            # Negatives are appended after the positive contexts, so row i
            # of the context batch (i < local batch size) is query i's
            # positive.
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])

        # Forward model.
        output_tensor = model(query_tokens, query_mask,
                              query_types, context_tokens,
                              context_mask, context_types)
        return output_tensor, partial(cross_entropy_loss_func,
                                      query_tokens, context_tokens)

    def _all_gather_with_grad(tensor, group, rank, world_size):
        """All-gather ``tensor`` along dim 0, keeping autograd on the local shard.

        The tensors returned by all_gather carry no gradient. The local rank's
        slot is therefore swapped back to the original ``tensor`` before
        concatenation, so gradients flow through this rank's own rows only.
        """
        input_ = tensor.detach().clone()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, input_, group=group)

        # Check if all-gather happens in order
        assert tensor_list[rank].sum().item() == tensor.sum().item()

        # Preserves the gradient
        tensor_list[rank] = tensor
        return torch.cat(tensor_list, dim=0).contiguous()

    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
        """Contrastive retrieval loss over queries/contexts from all ranks.

        Scores every gathered query against every gathered context with a dot
        product. The scores are optionally scaled by 1/sqrt(hidden_size), then
        log-softmaxed; the loss is the NLL of each query's positive context.

        Returns:
            A tuple of the loss (multiplied by the data-parallel world size)
            and a logging dict with the data-parallel-averaged loss and
            correct-prediction count.
        """
        args = get_args()

        local_batch_size = query_tokens.shape[0]
        group, rank, world_size = get_group_world_size_rank()
        # recall we assert that model_parallel_size == 1
        global_batch_size = world_size * local_batch_size

        query_logits, context_logits = output_tensor

        if world_size > 1:
            # Gather contexts first, then queries (same order on every rank).
            all_context_logits = _all_gather_with_grad(
                context_logits, group, rank, world_size)
            all_query_logits = _all_gather_with_grad(
                query_logits, group, rank, world_size)
        else:
            all_query_logits = query_logits
            all_context_logits = context_logits

        retrieval_scores = torch.matmul(all_query_logits,
                                        torch.transpose(all_context_logits,
                                                        0, 1))
        # Scaling the retrieval scores
        if args.retriever_score_scaling:
            retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

        if args.train_with_neg:
            # if the world size is 3, local batch size is 4, and
            # local context size is 8, what we want is
            # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
            local_context_size = context_tokens.shape[0]
            labels = [i * local_context_size + j
                      for i in range(world_size)
                      for j in range(local_batch_size)]
            labels = torch.LongTensor(labels).cuda()
            assert len(labels) == global_batch_size
        else:
            labels = torch.arange(global_batch_size).long().cuda()

        # Cross-entropy loss.
        softmax_scores = F.log_softmax(retrieval_scores, dim=1)

        loss = F.nll_loss(softmax_scores, labels, reduction='mean')

        _, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == labels).sum().float()

        # Reduce loss for logging.
        reduced_loss = average_losses_across_data_parallel_group(
            [loss, correct_predictions_count])

        # Loss scaling for correct losses in Supervised Retrieval
        loss = loss * mpu.get_data_parallel_world_size()

        return loss, {'lm loss': reduced_loss[0],
                      'correct_prediction_count': reduced_loss[1]}

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training',
                                args.train_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=False)
        valid_dataset = Dataset('validation',
                                args.valid_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=True)
        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        print_rank_0('building retriever model for {} ...'.format(args.task))

        model = biencoder_model_provider(
            only_context_model=False,
            only_query_model=False,
            biencoder_shared_query_context_model=\
                args.biencoder_shared_query_context_model,
            pre_process=pre_process, post_process=post_process)

        return model

    def single_dataset_provider(datapath):
        """Build an evaluation dataset named after the first file's basename."""
        args = get_args()
        tokenizer = get_tokenizer()

        name = datapath[0].split('/')[-1].split('.')[0]
        return Dataset(name,
                       datapath,
                       tokenizer,
                       args.retriever_seq_length,
                       evaluate=True)

    def metrics_func_provider():
        """Provide metrics callback function."""
        return accuracy_func_provider(single_dataset_provider)

    # Finetune/evaluate.
    finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)

def main():
    """Select the dataset class for ``args.task`` and run ORQA finetuning.

    Raises:
        NotImplementedError: if ``args.task`` is not 'RET-FINETUNE-NQ'.
    """
    args = get_args()

    # Guard clause: only the NQ supervised retrieval task is supported.
    if args.task != 'RET-FINETUNE-NQ':
        raise NotImplementedError('ORQA task {} is not implemented.'.format(
            args.task))

    from tasks.orqa.supervised.data import NQSupervisedDataset
    orqa(NQSupervisedDataset)
mpatwary's avatar
mpatwary committed
282