build.py 4.82 KB
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

import numpy as np
import os
import shutil
import torch
from tqdm import tqdm

from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import DiskDataParallelBertEmbedder
from tools.retro.db.utils import (
    get_indexed_dataset_infos,
    get_merged_sampled_dataset,
    get_merged_train_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.utils import GPTToTextDataset

from .utils import (
liangjing's avatar
v1  
liangjing committed
21
22
23
24
    get_training_data_block_dir,
    get_training_data_block_paths,
    get_training_data_merged_path,
    get_training_data_root_dir,
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
)


##################################################
# Train index.
##################################################


def get_empty_index_path():
    '''Return the path where the trained (still empty) index is stored.

    The path is provided by the index class selected through
    `args.retro_index_type`; this function only forwards the lookup.
    '''
    index_type = get_retro_args().retro_index_type
    return IndexFactory.get_index(index_type).get_empty_index_path()


liangjing's avatar
v1  
liangjing committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def get_block_nload(block_path, load_fraction):
    '''Number of leading rows to load from one embedding block file.

    Args:
        block_path: path to an HDF5 block file holding a "data" dataset.
        load_fraction: fraction of the dataset's first-axis length to load.

    Returns:
        int(load_fraction * rows), i.e. truncated toward zero.
    '''
    # Open read-only explicitly: older h5py versions defaulted to append
    # mode, which requires write access and may touch the file.
    with h5py.File(block_path, "r") as fi:
        return int(load_fraction * fi["data"].shape[0])


def merge_embedding_blocks():
    '''Concatenate the per-block training embeddings into one raw binary file.

    Only rank 0 does any work. For each block file, the leading
    `retro_index_train_load_fraction` share of its "data" rows is appended,
    as raw bytes, to the file at `get_training_data_merged_path()`. If that
    file already exists, nothing is done.

    The merged file is written under a temporary name and renamed into place
    only after every block has been written. An interrupted run therefore
    cannot leave a truncated file that a later run would mistake for a
    finished merge.
    '''

    if torch.distributed.get_rank() != 0:
        return

    args = get_retro_args()

    # Get block, merged paths.
    load_fraction = args.retro_index_train_load_fraction
    block_paths = get_training_data_block_paths()
    bin_path = get_training_data_merged_path()

    # Skip, if already built.
    if os.path.exists(bin_path):
        return

    # Merge blocks into a temporary file first. A stale temp file left by a
    # crashed run is simply overwritten ("wb").
    tmp_path = f"{bin_path}.tmp"
    try:
        with open(tmp_path, "wb") as fo:
            for block_path in tqdm(block_paths, "merge train embeddings"):
                with h5py.File(block_path, "r") as fi:
                    data = fi["data"]
                    # Same row count as get_block_nload(), computed from the
                    # already-open file instead of reopening it.
                    nload = int(load_fraction * data.shape[0])
                    # Sequential writes advance the file position, so no
                    # explicit seek is needed between blocks.
                    fo.write(np.asarray(data[:nload]).tobytes())
    except BaseException:
        # Never leave a partial merge behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    # Publish the complete file (atomic within a single filesystem).
    os.replace(tmp_path, bin_path)


Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
78
79
80
81
82
83
84
85
86
def embed_db():
    '''Embed the sampled DB chunks used for index training.

    The chunks are embedded and stored on disk in blocks under
    `get_training_data_block_dir()`. The blocks are then merged into a single
    file by `merge_embedding_blocks()`. Returns immediately if the merged file
    already exists.
    '''

    args = get_retro_args()

    # A previous run already embedded and merged everything.
    if os.path.exists(get_training_data_merged_path()):
        return

    # Sampled DB dataset, wrapped so that it yields text.
    text_dataset = GPTToTextDataset(get_merged_sampled_dataset())

    # Embed to disk, block by block.
    embedder = DiskDataParallelBertEmbedder(
        args.retro_bert_batch_size,
        args.retro_bert_max_chunk_length,
        args.retro_block_size,
        args.bert_embedder_type,
    )
    embedder.embed_text_dataset(
        "index", get_training_data_block_dir(), text_dataset)

    # Combine the per-block embeddings into one file.
    merge_embedding_blocks()
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
106
107
108
109
110
111


def train_on_embeddings():
    '''Train the configured index type on the embedded DB chunks.'''
    index_type = get_retro_args().retro_index_type
    IndexFactory.get_index(index_type).train()
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
113
114
115
116
117
118
119
120
121


def remove_embeddings():
    '''Delete the training embeddings after the index has been trained.

    All ranks synchronize first; only rank 0 then removes
    `get_training_data_root_dir()` (errors during removal are ignored).

    Raises:
        FileNotFoundError: on rank 0, if the trained (empty) index file does
            not exist; nothing is deleted in that case.
    '''
    torch.distributed.barrier()
    if torch.distributed.get_rank() != 0:
        return
    empty_index_path = get_empty_index_path()
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and embeddings must never be deleted without a saved index.
    if not os.path.isfile(empty_index_path):
        raise FileNotFoundError(
            f"Trained index not found at '{empty_index_path}'; "
            "refusing to delete training embeddings.")
    shutil.rmtree(get_training_data_root_dir(), ignore_errors=True)
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142


def train_index():
    '''Train the index on sampled DB chunks, unless it is already trained.

    If no trained index file exists yet, the sampled chunks are embedded and
    the index is trained on them. All ranks then synchronize. Afterwards the
    training embeddings are deleted when
    `args.retro_index_delete_training_embeddings` is set.
    '''

    args = get_retro_args()

    # Embed + train only when no trained index is on disk yet.
    already_trained = os.path.isfile(get_empty_index_path())
    if not already_trained:
        embed_db()
        train_on_embeddings()

    # Wait for (single-process) training to complete.
    torch.distributed.barrier()

    # Optionally reclaim the disk space used by the training embeddings.
    if args.retro_index_delete_training_embeddings:
        remove_embeddings()


##################################################
# Add to index.
##################################################


def add_to_index():
    '''Add all training DB chunks to the configured index.

    Returns:
        The value returned by the index's `add()` method, presumably the
        output index path (verify against the index implementation).
    '''
    index = IndexFactory.get_index(get_retro_args().retro_index_type)
    text_dataset = GPTToTextDataset(get_merged_train_dataset())
    return index.add(text_dataset)


##################################################
# Build index (train + add).
##################################################


def build_index():
    '''Build index.

    Building index involves sequentially running stages above:
    - Train index (on sampled training chunks).
    - Add to index (on all training chunks).

    Returns:
        Whatever `add_to_index()` returns (previously discarded), so callers
        can use the built index's output without recomputing it.
    '''

    # Train index.
    train_index()

    # Add to index.
    return add_to_index()