import sys
import time
import torch
import torch.distributed as dist

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
        self.pre_process = True
        self.post_process = True

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)
        #self.using_realm_chkpt = args.ict_load is None

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
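        # Each data-parallel rank acts as one index builder; the "main
        # builder" (data-parallel rank 0) later merges the per-rank shards.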
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        args = get_args()
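        # Only the context (evidence) tower is needed for indexing, unless
        # query and context models share weights, in which case the full
        # biencoder has to be constructed and loaded.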
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        #args.only_context_model = only_context_model
        #args.only_query_model = False

        #model = get_model(biencoder_model_provider)

        model = get_model(get_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=\
                self.biencoder_shared_query_context_model))

        #model = get_model(lambda: biencoder_model_provider(only_context_model \
        #    = only_context_model, biencoder_shared_query_context_model = \
        #    self.biencoder_shared_query_context_model,
        #    pre_process=True, post_process=True)

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)

        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
            self.batch_size))

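        # Start from an empty embedding store: load_from_path=False means no
        # previously pickled index is read from disk.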
        self.evidence_embedder_obj = OpenRetreivalDataStore( \
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        Each process saves its copy of BlockData as a shard; when run in a
        distributed setting, the rank 0 process consolidates the shards and
        saves them as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]

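        # Peel off wrapper modules (e.g. DDP / fp16 wrappers) by descending
        # through .module until we reach the module exposing embed_text().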
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        counter = 0
        start_time = time.time()
        cur_time = start_time
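        # Iterate over exactly one epoch of evidence batches; the dataloader
        # signals exhaustion with StopIteration (or IndexError).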
        while True:
            #start_time = time.time()
            t1 = time.time()
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch( \
                    self.dataloader)
            except (StopIteration, IndexError):
                break

            #print_rank_0("get batch time {}".format(cur_time - time.time()))
            t2 = time.time()
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            # Run the forward pass under torch.no_grad() to reduce memory
            # usage; gradients are never needed when building the index.
            with torch.no_grad():
                context_logits = unwrapped_model.embed_text(
                    unwrapped_model.context_model, context_tokens,
                    context_mask, context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)
            #print_rank_0("embed text {}".format(cur_time - time.time()))
            t3 = time.time()
 
            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))
            #print_rank_0("add block time {}".format(cur_time - time.time()))
            t4 = time.time()
            counter += 1
            if counter % 1000 == 0:
                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
                cur_time = time.time()
        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
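

# --------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# in Megatron-LM, IndexBuilder is normally driven by a separate tool script
# after Megatron has been initialized. The args_defaults below (tokenizer
# choice) are assumptions for the sketch, not requirements of IndexBuilder.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    from megatron.initialize import initialize_megatron

    # Set up torch.distributed, global args and the tokenizer before
    # building the evidence index.
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

    index_builder = IndexBuilder()
    index_builder.build_and_save_index()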