import sys

import torch
import torch.distributed as dist

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
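        # whether the query and context encoders share a single set of weights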
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
        self.pre_process = True
        self.post_process = True

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
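        # each data-parallel rank acts as one index builder over its own
        # shard of the evidence data; the main builder (rank 0) reports
        # progress and later merges the shards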
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        args = get_args()

        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(get_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=\
                self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)

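        # load_biencoder_checkpoint returns the model as a list; indexing
        # uses a single model, kept in eval mode since this is inference only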
        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
            self.batch_size))

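        # start from a fresh, empty embedding store; load_from_path=False
        # means no previously saved embeddings are read from disk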
        self.evidence_embedder_obj = OpenRetreivalDataStore( \
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]

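        # the model may be wrapped (e.g. DDP / fp16 wrappers), so unwrap
        # .module layers until we reach the one that exposes embed_text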
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

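        # drain exactly one epoch of the dataloader; exhaustion surfaces as
        # StopIteration (or IndexError) below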
        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch( \
                    self.dataloader)
            except (StopIteration, IndexError):
                break

            # detach, separate fields and add to BlockData; embed under
            # torch.no_grad() so no autograd graph is kept around (resolving
            # the original TODO about reducing memory usage)
            assert context_mask.dtype == torch.bool
            with torch.no_grad():
                context_logits = unwrapped_model.embed_text(
                    unwrapped_model.context_model, context_tokens,
                    context_mask, context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)

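            # accumulate (row_id -> embedding) pairs into this rank's shard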
            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
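

# Minimal usage sketch (an illustrative assumption, not part of the original
# module): IndexBuilder is normally driven by a launcher script that first
# runs Megatron's initialization, which parses the arguments used above
# (e.g. --indexer-batch-size, --indexer-log-interval) and sets up the
# torch.distributed process groups.
if __name__ == '__main__':
    from megatron.initialize import initialize_megatron

    initialize_megatron(extra_args_provider=None)
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()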