import torch
import torch.distributed as dist

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import load_ict_checkpoint
from megatron.data.ict_dataset import get_ict_dataset
from megatron.data.biencoder_dataset_utils import get_ict_batch, \
    get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
#from megatron.model.realm_model import general_ict_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.block_data = None

        # at most one of args.load (REALM checkpoint) and args.ict_load
        # (ICT checkpoint) may be specified
        assert not (args.load and args.ict_load)
        self.using_realm_chkpt = args.ict_load is None

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
        model = get_model(lambda: biencoder_model_provider(only_context_model=True))
        self.model = load_ict_checkpoint(
            model, only_context_model=True,
            from_realm_chkpt=self.using_realm_chkpt)
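        # indexing is inference only, so run the context encoder in eval mode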
        self.model.eval()
        self.dataset = get_ict_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
        self.block_data = OpenRetreivalDataStore(load_from_path=False)
 
    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress"""
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

    def build_and_save_index(self):
        """Go through one epoch of the dataloader, adding all data to this
        instance's BlockData.

        Each process saves its copy of BlockData as a shard; in a distributed
        setting, the rank 0 process then consolidates the shards and saves the
        final pickled BlockData.
        """

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
            except (StopIteration, IndexError):
                break

            # unwrap DDP/FP16 wrappers until we reach the module that
            # implements embed_block
            unwrapped_model = self.model
            while not hasattr(unwrapped_model, 'embed_block'):
                unwrapped_model = unwrapped_model.module

            # detach, separate fields and add to BlockData
            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
            detached_data = detach(block_sample_data)

            # block_sample_data is a 2D array [batch x 4]
            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
            block_indices = detached_data[:, 3]
            block_metas = detached_data[:, :3]

            self.block_data.add_block_data(block_indices, block_logits, block_metas)
            self.track_and_report_progress(batch_size=block_tokens.shape[0])

        # each process saves its own shard, then synchronizes with the others
        self.block_data.save_shard()
        dist.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.block_data.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.block_data.embed_data) == len(self.dataset)
        self.block_data.clear()
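

# Minimal driver sketch (illustrative; modeled on Megatron-LM's
# tools/create_doc_index.py, and the argument defaults below are assumptions).
# initialize_megatron() parses the command-line flags that IndexBuilder reads
# through get_args(), e.g. --indexer-batch-size, --indexer-log-interval and
# --ict-load.
if __name__ == '__main__':
    from megatron.initialize import initialize_megatron

    initialize_megatron(
        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()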