import sys
import time

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider
from indexer_utils import (set_index_com_file_ready,
                           set_model_com_file_not_ready,
                           check_model_com_file_ready)


def test_retriever():
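    """Sanity-check retrieval: embed a few natural-language queries and print
    the top-k evidence blocks returned by the FAISS index."""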
    # TODO: Update this because it's outdated and definitely won't run.
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    model = load_ict_checkpoint()
    model.eval()
    dataset = get_ict_dataset()

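    # Load precomputed block embeddings and index them with a flat (exact)
    # inner-product FAISS index; 128 is assumed to be the embedding dim.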
    block_data = BlockData.load_from_file(args.block_data_path)
    mips_index = FaissMIPSIndex('flat_ip', 128)
    mips_index.add_block_embed_data(block_data)
    retriever = REALMRetriever(model, dataset, block_data, mips_index, top_k=5)

    strs = [
        "The last monarch from the house of windsor",
        "married to Elvis Presley",
        "tallest building in the world today",
        "who makes graphics cards"
    ]

    for s in strs:
        retriever.retrieve_evidence_blocks_text(s)


def main():
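    """Loop forever (re)building the block embedding index.

    Each pass loads the latest block encoder, embeds every block in the
    dataset, saves per-rank shards, consolidates them on rank 0, and signals
    via a communication file that a fresh index is ready.
    """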
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    while True:
        # Prefer the block encoder weights from a full REALM checkpoint;
        # fall back to a plain ICT checkpoint if none can be loaded.
        try:
            model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
        except Exception:
            model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        model.eval()
        dataset = get_ict_dataset()
        data_iter = iter(get_one_epoch_dataloader(dataset))
        all_block_data = BlockData()

        i = 1
        total = 0
        while True:
            with torch.no_grad():
                try:
                    query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
                except StopIteration:
                    # Data iterator exhausted: every block has been embedded.
                    break

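                # Each row of block_index_data is assumed to hold
                # [start_idx, end_idx, doc_idx, block_idx]; split it into
                # metadata columns and the global block indices.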
                block_index_data = detach(block_index_data)
                block_indices = block_index_data[:, 3]
                block_meta = block_index_data[:, :3]

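                # Embed the blocks with only the block (context) encoder.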
                block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
                all_block_data.add_block_data(block_indices, block_logits, block_meta)

                total += block_indices.size
                i += 1
                if i % 2000 == 0:
                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
                    if args.debug:
                        break

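        # Each rank saves its shard of embeddings; rank 0 then merges all
        # shards into one consolidated index file after the barrier.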
        all_block_data.save_shard(args.rank)
        torch.distributed.barrier()
        del model

        if args.rank == 0:
            all_block_data.consolidate_shards_and_save()
        else:
            all_block_data.clear()

        set_index_com_file_ready()
        torch.distributed.barrier()
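        # In async mode, block until the trainer signals that a new model
        # checkpoint is ready, honoring any autoresume termination request.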
        if args.async_indexer:
            while not check_model_com_file_ready():
                time.sleep(5)
                autoresume = get_adlr_autoresume()
                if autoresume.termination_requested():
                    print_rank_0(">>> autoresume termination request found!")
                    if torch.distributed.get_rank() == 0:
                        autoresume.request_resume()
                    print_rank_0(">>> training terminated. Returning")
                    sys.exit(0)

            set_model_com_file_not_ready()


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
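    """Load an ICT model, optionally keeping only one tower, from either a
    plain ICT checkpoint or the ICT weights nested inside a REALM checkpoint."""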
    args = get_args()
    model = get_model(lambda: model_provider(only_query_model, only_block_model))

    load_path = args.load if from_realm_chkpt else args.ict_load

    if isinstance(model, torchDDP):
        model = model.module
    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
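    # A REALM checkpoint nests the ICT weights under retriever/ict_model.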
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        ict_state_dict = ict_state_dict['retriever']['ict_model']

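    # Drop the unused tower so the state dict matches the partial model.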
    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')
    if no_grad:
        with torch.no_grad():
            model.load_state_dict(ict_state_dict)
    else:
        model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_ict_dataset(use_titles=True):
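    """Build an ICTDataset spanning one epoch of the full block corpus; the
    sequence-length settings are placeholders since no training happens here."""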
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=1,
        use_titles=use_titles
    )
    dataset = ICTDataset(**kwargs)
    return dataset


def get_one_epoch_dataloader(dataset):
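    """Build a data-parallel dataloader that makes one sequential pass over
    the given dataset."""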
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    main()