import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.ict_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_ict import get_batch, general_ict_model_provider


def pprint(*args):
    print(*args, flush=True)


class IndexBuilder(object):
    """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.block_data = None

        # We need to know whether we are loading a REALM checkpoint (args.load) or an ICT
        # checkpoint (args.ict_load); at most one of the two may be provided.
        assert not (args.load and args.ict_load)
        self.using_realm_chkpt = args.ict_load is None

        self.load_attributes()
        self.is_main_builder = args.rank == 0
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
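        # Only the block (context) encoder is needed for indexing; it can come from either an
        # ICT checkpoint or a REALM checkpoint.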
        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
        self.model.eval()
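        # A single sequential pass over the ICT dataset; the distributed batch sampler gives each
        # data-parallel rank its own slice of every global batch.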
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
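        # Fresh, empty BlockData; embeddings accumulate here and are saved as a per-rank shard.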
        self.block_data = BlockData(load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress"""
        self.iteration += 1
        self.total_processed += batch_size
        if self.iteration % 10 == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

    def build_and_save_index(self):
        """Go through one epoch of the dataloader and add all its data to this BlockData.

        Each process saves its copy of BlockData as a shard; in a distributed setting, the rank 0
        process consolidates the shards and saves the final pickled BlockData.
        """

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                _, _, block_tokens, block_pad_mask, block_sample_data = get_batch(self.dataloader)
            except (StopIteration, IndexError):
                # the dataloader is exhausted: one full epoch has been embedded
                break

            # detach, setup and add to BlockData
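            # unwrap DDP (and any other wrappers) to reach the module that exposes embed_block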
            unwrapped_model = self.model
            while not hasattr(unwrapped_model, 'embed_block'):
                unwrapped_model = unwrapped_model.module
            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))

            detached_data = detach(block_sample_data)
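            # block_sample_data rows come from the block samples mapping and are assumed to be
            # (start_idx, end_idx, doc_idx, block_idx); the last column is the unique block id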
            block_indices = detached_data[:, 3]
            block_metas = detached_data[:, :3]

            self.block_data.add_block_data(block_indices, block_logits, block_metas)
            self.track_and_report_progress(batch_size=block_tokens.shape[0])

        # Each process finalizes its own shard and then synchronizes with the other processes
        self.block_data.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.block_data.merge_shards_and_save()
        self.block_data.clear()


def load_ict_checkpoint(only_query_model=False, only_block_model=False, from_realm_chkpt=False):
    """Load an ICT checkpoint for indexing/retrieval; the arguments select which parts of the
    state dict to actually use."""
    args = get_args()
    model = get_model(lambda: general_ict_model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    # keep only the tower being loaded so the state dict matches the partial model
    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset that uses the block samples mapping to provide ICT/block indexing data."""
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=True
    )
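    # query_in_block_prob=1 keeps the query sentence inside its block, so each block is embedded
    # in full (an assumption about the ICT dataset's sampling behavior)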
    dataset = ICTDataset(**kwargs)
    return dataset


def get_one_epoch_dataloader(dataset, batch_size=None):
    """Build a dataloader that runs for exactly one epoch, for use in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    # This entry point is for basic (as opposed to REALM async) indexing jobs.
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
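
# Example launch for a standalone indexing job (illustrative only; flag names and values are
# assumptions that depend on the Megatron configuration used for ICT pretraining):
#   python -m torch.distributed.launch --nproc_per_node=8 indexer.py \
#       --ict-load <path/to/ict/checkpoint> \
#       --data-path <path/to/block/dataset> \
#       --titles-data-path <path/to/titles/dataset>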