import sys
import torch
import torch.distributed as dist

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model
        self.pre_process = True
        self.post_process = True

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(lambda: biencoder_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=\
                self.biencoder_shared_query_context_model,
            pre_process=self.pre_process,
            post_process=self.post_process))

        self.model = load_biencoder_checkpoint(model,
                only_context_model=only_context_model)

        self.model.eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
            self.batch_size))

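        # Start from an empty data store (load_from_path=False); embeddings
        # are added to it shard-by-shard in build_and_save_index().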
        self.evidence_embedder_obj = OpenRetreivalDataStore( \
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress
        """
        self.iteration += 1
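        # Every data-parallel rank embeds batch_size rows per step, so the
        # global progress counter advances by batch_size * num_total_builders.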
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        # Strip wrappers (e.g. DDP or the fp16 module wrapper) by following
        # .module until the model that exposes embed_text() is reached.
        unwrapped_model = self.model
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch( \
                    self.dataloader)
            except (StopIteration, IndexError):
                break

            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # This process finalizes its shard and then synchronizes with the
        # other processes.
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
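

# Usage sketch (illustrative, not part of this module): the builder assumes
# Megatron has already been initialized (e.g. via
# megatron.initialize.initialize_megatron) with the biencoder checkpoint and
# the indexer_batch_size / indexer_log_interval arguments used above. A
# driver script would then simply do:
#
#     index_builder = IndexBuilder()
#     index_builder.build_and_save_index()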