"src/targets/vscode:/vscode.git/clone" did not exist on "8b81e1116f47fcf3b8dea39983cf88b54de284fd"
indexer.py 7.01 KB
Newer Older
Neel Kant's avatar
Neel Kant committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
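"""Build a BlockData of block embeddings by taking one pass over an ICT dataset.

Each data-parallel rank embeds its share of the blocks with the ICT block
encoder and saves them as a BlockData shard; rank 0 then merges the shards
into the final saved index.
"""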
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, general_ict_model_provider


def pprint(*args):
    """print() that always flushes, so progress is visible promptly in batch-job logs."""
    print(*args, flush=True)


class IndexBuilder(object):
    """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()
        self.is_main_builder = args.rank == 0
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
        # TODO: handle from_realm_chkpt correctly
        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress"""
        self.iteration += 1
        self.total_processed += batch_size
        if self.iteration % 10 == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

    def build_and_save_index(self):
        """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a distributed setting will be
        consolidated by the rank 0 process and saved as a final pickled BlockData.
        """

        while True:
            try:
                # batch also has query_tokens and query_pad_data, unused here
                _, _, block_tokens, block_pad_mask, block_sample_data = get_batch(self.dataloader)
            except (StopIteration, IndexError):
                # the dataloader is exhausted: the single epoch is complete
                break

            # unwrap DDP (and any other wrappers) until we reach the module
            # that exposes embed_block, then embed the blocks and detach
            unwrapped_model = self.model
            while not hasattr(unwrapped_model, 'embed_block'):
                unwrapped_model = unwrapped_model.module
            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))

            # each row of block_sample_data packs one BlockSampleData record:
            # columns 0-2 are the (start_idx, end_idx, doc_idx) metadata and
            # column 3 is the block index
            detached_data = detach(block_sample_data)
            block_indices = detached_data[:, 3]
            block_metas = detached_data[:, :3]

            self.block_data.add_block_data(block_indices, block_logits, block_metas)
            self.track_and_report_progress(batch_size=block_tokens.shape[0])

        # each process saves its own shard, then synchronizes with the others
        self.block_data.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.block_data.merge_shards_and_save()
        self.block_data.clear()


def load_ict_checkpoint(only_query_model=False, only_block_model=False, from_realm_chkpt=False):
    """load ICT checkpoints for indexing/retrieving. Arguments specify which parts of the state dict to actually use."""
    args = get_args()
    model = get_model(lambda: general_ict_model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    # keep only the requested tower of the two-tower ICT model
    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
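
# A minimal usage sketch: the indexing job in this file loads only the block
# encoder; a retrieval job would instead load only the query encoder.
#   block_model = load_ict_checkpoint(only_block_model=True)
#   query_model = load_ict_checkpoint(only_query_model=True)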


def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block sample mappings to get ICT/block indexing data.

    query_in_block_prob defaults to 1 so that sampled query sentences are left
    inside their blocks, i.e. blocks are indexed unmodified.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=True
    )
    dataset = ICTDataset(**kwargs)
    return dataset


def get_one_epoch_dataloader(dataset, batch_size=None):
    """Build a dataloader that makes exactly one sequential pass over the dataset, for indexing jobs."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
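
# Worked example of the sharding above (assuming 4 data-parallel ranks and
# batch_size=8 per rank): global_batch_size = 32, so each group of 32
# consecutive samples is split into four disjoint slices of 8, one per rank.
# With drop_last=False the final, possibly smaller, batch is still emitted,
# so every block in the dataset gets embedded.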


if __name__ == "__main__":
    # This usage is for basic (as opposed to realm async) indexing jobs.
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
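
# A hypothetical launch for an 8-GPU indexing job (flag names follow common
# Megatron conventions and may differ between versions):
#   python -m torch.distributed.launch --nproc_per_node=8 indexer.py \
#       --ict-load /path/to/ict_checkpoint \
#       --data-path /path/to/block_data \
#       --titles-data-path /path/to/title_data \
#       --seq-length 256 --batch-size 128 --num-workers 2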