import sys
import time
import torch
import torch.distributed as dist

from megatron import get_args, print_rank_0
from megatron.core import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model

        # Need to know whether we're loading a REALM checkpoint (args.load)
        # or an ICT checkpoint (args.ict_load); the two are mutually exclusive.
        assert not (args.load and args.ict_load)

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(get_model_provider(only_context_model=\
            only_context_model, biencoder_shared_query_context_model=\
            self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)

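        # get_model() returns a list of model chunks; exactly one chunk is
        # expected for the indexer.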
        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
            self.batch_size))

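        # Start from an empty embedding store; this rank's shard is saved at
        # the end of build_and_save_index() and later merged by rank 0.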
        self.evidence_embedder_obj = OpenRetreivalDataStore( \
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        Each rank saves its copy of BlockData as a shard; in a distributed
        setting, the shards are consolidated by the rank 0 process and saved
        as the final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]

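        # Peel off wrappers (e.g. DDP or fp16 wrappers) until we reach the
        # module that exposes embed_text().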
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch( \
                    self.dataloader)
            except (StopIteration, IndexError):
                break

            # TODO: can we wrap this in torch.no_grad() to reduce memory usage?
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # Finalize this process's shard, then synchronize with the other
        # processes.
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
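

# Example driver (a minimal sketch, not part of the original module): in
# practice the index is built from a separate entry-point script after
# Megatron has been initialized. The exact initialization call below is an
# assumption for illustration; real entry points may pass tokenizer and
# argument defaults to initialize_megatron.
if __name__ == '__main__':
    from megatron.initialize import initialize_megatron

    initialize_megatron(extra_args_provider=None)
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()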