indexer.py 3.81 KB
Newer Older
Neel Kant's avatar
Neel Kant committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.distributed as dist

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import load_ict_checkpoint
from megatron.data.ict_dataset import get_ict_dataset
from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, BlockData
from megatron.data.realm_dataset_utils import get_ict_batch
from megatron.model.realm_model import general_ict_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """Object for taking one pass over a dataset and creating a BlockData of its embeddings.

    Construction reads the global Megatron args, loads the block-embedding model from a
    checkpoint, and prepares a one-epoch dataloader plus an empty BlockData.
    Calling build_and_save_index() then consumes the dataloader exactly once; the
    instance is single-use because the model attribute is deleted at the end of that call.
    """

    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.block_data = None

        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
        assert not (args.load and args.ict_load), \
            'args.load and args.ict_load must not both be set'
        self.using_realm_chkpt = args.ict_load is None

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        # the builder with data parallel rank 0 is the one that logs and merges shards
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), self.batch_size))
        self.block_data = BlockData(load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress.

        Increments the batch counter and adds batch_size * num_total_builders to the running
        total, i.e. the total is an estimate that assumes every builder just handled a batch
        of the same size. Only the main builder prints, once every log_interval batches; a
        falsy log_interval (0 or None) disables printing instead of raising on the modulo.
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if not (self.is_main_builder and self.log_interval):
            return
        if self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

    def _unwrap_block_model(self):
        """Follow .module links from self.model until an object exposing embed_block is found."""
        unwrapped_model = self.model
        while not hasattr(unwrapped_model, 'embed_block'):
            unwrapped_model = unwrapped_model.module
        return unwrapped_model

    def build_and_save_index(self):
        """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a distributed setting will be
        consolidated by the main builder (data parallel rank 0) and saved as the final BlockData.
        """
        # The wrapper chain does not change between batches, so resolve it once up front.
        block_model = self._unwrap_block_model()

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
            except StopIteration:
                break

            # The embeddings are detached right away, so there is no reason to let autograd
            # record the forward pass.
            with torch.no_grad():
                block_logits = detach(block_model.embed_block(block_tokens, block_pad_mask))
            detached_data = detach(block_sample_data)

            # block_sample_data is a 2D array [batch x 4]
            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
            block_indices = detached_data[:, 3]
            block_metas = detached_data[:, :3]

            self.block_data.add_block_data(block_indices, block_logits, block_metas)
            self.track_and_report_progress(batch_size=block_tokens.shape[0])

        # This process signals to finalize its shard and then synchronize with the other processes
        self.block_data.save_shard()
        torch.distributed.barrier()

        # Drop every reference to the model before merging: the local handle would otherwise
        # keep the inner module alive even after the attribute is deleted.
        del block_model
        del self.model

        # the main builder (data parallel rank 0) builds the final copy
        if self.is_main_builder:
            self.block_data.merge_shards_and_save()
        self.block_data.clear()