build.py 3.5 KB
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

import numpy as np
import os
import shutil
import torch
from tqdm import tqdm

from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import DiskDataParallelBertEmbedder
from tools.retro.db.utils import (
    get_indexed_dataset_infos,
    get_merged_sampled_dataset,
    get_merged_train_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.utils import GPTToTextDataset

from .utils import (
    get_training_data_dir,
    get_training_data_merged,
)


##################################################
# Train index.
##################################################


def get_empty_index_path():
    '''Return the empty-index path reported by the configured index.

    The index object is obtained from IndexFactory using the
    `retro_index_type` Retro argument, and the path is whatever that
    index's own `get_empty_index_path()` returns.
    '''
    retro_args = get_retro_args()
    return IndexFactory.get_index(
        retro_args.retro_index_type).get_empty_index_path()


def embed_db():
    '''Embed the DB chunks used for training the index.

    The embedder stores the chunk embeddings in blocks on disk (under the
    training data directory). These blocks are later merged into a single
    dataset for training the index.
    '''

    retro_args = get_retro_args()

    # Wrap the merged sampled GPT dataset so that it yields text.
    sampled_text = GPTToTextDataset(get_merged_sampled_dataset())

    # Build the disk-backed, data-parallel Bert embedder from the Retro args.
    bert_embedder = DiskDataParallelBertEmbedder(
        retro_args.retro_bert_batch_size,
        retro_args.retro_bert_max_chunk_length,
        retro_args.retro_block_size,
        retro_args.bert_embedder_type,
    )

    # Embed the text dataset into the training data directory.
    output_dir = get_training_data_dir()
    bert_embedder.embed_text_dataset("index", output_dir, sampled_text)


def train_on_embeddings():
    '''Train the configured index on the embedded DB chunks.'''
    retro_args = get_retro_args()
    trainable_index = IndexFactory.get_index(retro_args.retro_index_type)

    # NOTE: the loader function itself is handed over (it is not called
    # here); presumably the index invokes it when it needs the merged
    # training data -- confirm against the index implementation.
    trainable_index.train(get_training_data_merged)


def remove_embeddings():
    '''Remove the index-training embeddings after training.

    All ranks first synchronize on a barrier; afterwards only rank 0
    continues. Rank 0 verifies that the empty index file exists and then
    deletes the training data directory. The deletion is best effort:
    errors raised while removing the directory are ignored.

    Raises:
        AssertionError: if the empty index file does not exist (raised on
            rank 0 only).
    '''
    torch.distributed.barrier()
    if torch.distributed.get_rank() != 0:
        return

    # Use an explicit raise instead of an `assert` statement: asserts are
    # stripped when Python runs with -O, which would let the deletion below
    # proceed even though no trained index exists. The exception type is
    # kept as AssertionError so existing callers see the same failure.
    empty_index_path = get_empty_index_path()
    if not os.path.isfile(empty_index_path):
        raise AssertionError(
            "empty index not found at '%s'; refusing to delete the "
            "index training embeddings." % empty_index_path)

    shutil.rmtree(get_training_data_dir(), ignore_errors=True)


def train_index():
    '''Train the index on DB chunks, unless it has already been trained.

    If the empty index file is missing, the DB chunks are embedded and the
    index is trained on those embeddings. All ranks then synchronize, and
    the training embeddings are removed when the
    `retro_delete_index_training_embeddings` Retro argument is set.
    '''

    retro_args = get_retro_args()

    # An existing empty index file means training was already done.
    already_trained = os.path.isfile(get_empty_index_path())
    if not already_trained:
        # Embed the training chunks, then train the index on them.
        embed_db()
        train_on_embeddings()

    # Wait for (single-process) training to complete.
    torch.distributed.barrier()

    # Optionally clean up the embeddings that were used for training.
    if retro_args.retro_delete_index_training_embeddings:
        remove_embeddings()


##################################################
# Add to index.
##################################################


def add_to_index():
    '''Add DB chunks to the index.

    Returns whatever the index's `add()` call returns (the output index
    path).
    '''

    retro_args = get_retro_args()

    # Index selected by the `retro_index_type` Retro argument.
    target_index = IndexFactory.get_index(retro_args.retro_index_type)

    # Text view over the merged train dataset.
    train_text = GPTToTextDataset(get_merged_train_dataset())

    # Add the chunks and hand back the resulting index path.
    return target_index.add(train_text)


##################################################
# Build index (train + add).
##################################################


def build_index():
    '''Build the index by running its two stages in order.

    Stages:
    - Train index (on sampled training chunks).
    - Add to index (on all training chunks).
    '''

    # Order matters: the index must be trained before chunks are added.
    stages = (train_index, add_to_index)
    for run_stage in stages:
        run_stage()