# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import sys
import time
import torch
import torch.distributed as dist

from megatron.training import get_args, print_rank_0
from megatron.core import mpu
from megatron.training.checkpointing import load_biencoder_checkpoint
from megatron.legacy.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.legacy.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.legacy.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.legacy.data.realm_index import detach, OpenRetreivalDataStore
from megatron.legacy.model.biencoder_model import get_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
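        # If the biencoder shares a single model between queries and
        # contexts, the whole model is needed for indexing; otherwise only
        # the context encoder is loaded (see load_attributes below).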
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
        # Need to know whether we're loading a REALM checkpoint (args.load)
        # or an ICT checkpoint (args.ict_load); they are mutually exclusive.
        assert not (args.load and args.ict_load)

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
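        # Each data-parallel rank embeds a disjoint shard of the evidence;
        # rank 0 of the data-parallel group later merges all shards.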
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(get_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=
                self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)

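        # get_model returns a list of model chunks (one per virtual
        # pipeline stage); virtual pipeline parallelism is not used here,
        # so exactly one chunk is expected.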
        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(
            self.dataset, self.batch_size))

        self.evidence_embedder_obj = OpenRetreivalDataStore(
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.
        The copy of BlockData is saved as a shard; when run in a distributed
        setting, the shards are consolidated by the rank 0 process and saved
        as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]

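        # Unwrap containers (e.g. DDP or Float16Module) until reaching the
        # module that actually exposes embed_text.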
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch(
                        self.dataloader)
            except (StopIteration, IndexError):
                break

            # TODO: can we wrap this in torch.no_grad() to reduce memory usage?
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # This process finalizes its shard and then synchronizes with the
        # other processes.
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
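        # The model is no longer needed; release it before the merge, which
        # can itself be memory-intensive for a large evidence set.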
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
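

# Usage sketch (illustrative, not part of the original module): the builder
# assumes Megatron has already been initialized on every rank.
# initialize_megatron is Megatron's standard setup entry point; the guard
# below is a hypothetical driver, and the exact arguments required depend on
# your configuration.
if __name__ == '__main__':
    from megatron.training.initialize import initialize_megatron

    # Parses Megatron command-line arguments and sets up distributed state.
    initialize_megatron()

    # One pass over the evidence dataset: every rank embeds its shard, and
    # rank 0 merges them into the final pickled BlockData.
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()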