# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

from collections import defaultdict
import glob
import json
import numpy as np
import os
from tqdm import tqdm

from megatron import get_retro_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from tools.retro.external_libs import h5py

from .dataset import DBDataset


def get_base_db_workdir():
    '''Sub-directory for DB data.'''
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "db")


def get_indexed_dataset_infos_path():
    '''Path to indexed dataset meta-infos.'''
    return os.path.join(get_base_db_workdir(), "indexed_dataset_infos.json")


def save_indexed_dataset_infos(indexed_dataset_infos):
    '''Save dataset order & meta-info.'''

    # Remove 'dataset' field.
    clean_infos = []
    for info in indexed_dataset_infos:
        info = dict(info)
        del info["dataset"]
        clean_infos.append(info)

    # Save.
    with open(get_indexed_dataset_infos_path(), "w") as f:
        json.dump(clean_infos, f, indent=4)


def get_indexed_dataset_infos():
    '''Load indexed dataset meta-infos.'''
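    # Each info dict is expected to contain at least 'prefix', 'db_dir',
    # 'n_chunks', and 'n_docs' (the fields referenced elsewhere in this
    # module); a live 'dataset' handle is attached below.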

    # Load json.
    path = get_indexed_dataset_infos_path()
    with open(path) as f:
        infos = json.load(f)

    # Add indexed datasets.
    for info in infos:
        info["dataset"] = make_indexed_dataset(info["prefix"], "mmap", True)

    return infos


def get_individual_db_dir(name):
    '''Individual DB's directory.'''
    return os.path.join(get_base_db_workdir(), "individual", name)


def get_individual_chunk_db(ds_id, ds_info):
    '''Load individual dataset's chunk DB.'''
    db_paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
    # *Note*: convert to dataset, rather than copying to memory.
    db = np.zeros((ds_info["n_chunks"], 5), dtype="uint32")
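    # Assumed row layout, inferred from how the columns are filled below:
    # [dataset_id, doc_id, token_start_idx, token_end_idx, bert_chunk_length].
    # Column 0 is set here; columns 1-4 come from each file's 'chunks_valid'.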
    db[:, 0] = ds_id
    start_idx = 0
    for db_path in db_paths:
        with h5py.File(db_path, "r") as f:
            n_chunks_current = f["chunks_valid"].shape[0]
            db[start_idx:(start_idx+n_chunks_current), 1:] = f["chunks_valid"]
            start_idx += n_chunks_current

    assert start_idx == ds_info["n_chunks"]

    return db


def get_individual_doc_offsets(ds_id, ds_info):
    '''Load individual dataset's document offsets.'''
    paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
    # *Note*: convert to dataset, rather than copying to memory.
    doc_offsets = np.zeros((ds_info["n_docs"], 3), dtype="uint64")
    doc_offsets[:, 0] = ds_id
    start_idx = 0
    start_offset = 0
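    # Column 0 is the dataset id; columns 1-2 are copied from each file's
    # 'doc_offsets', with column 1 shifted by the running offset so that
    # offsets accumulate across files.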
    for path in paths:
        with h5py.File(path) as f:
            current_doc_offsets = np.copy(f["doc_offsets"])
            current_doc_offsets[:, 1] += start_offset
            current_ndocs = current_doc_offsets.shape[0]
            doc_offsets[start_idx:(start_idx+current_ndocs), 1:] = \
                current_doc_offsets
            start_idx += current_ndocs
            start_offset = current_doc_offsets[-1, 1].item()

    return doc_offsets


def get_merged_db_path_map():
    '''Paths to merged datasets.'''
    base_dir = get_base_db_workdir()
    return {
        "sampled" : os.path.join(base_dir, "merged", "sampled.hdf5"),
        "train" : os.path.join(base_dir, "merged", "train.hdf5"),
        "valid" : os.path.join(base_dir, "merged", "valid.hdf5"),
    }


def get_merged_dataset(db_type, indexed_dataset_infos=None):
    '''Get merged dataset.'''

    args = get_retro_args()

    if not indexed_dataset_infos:
        indexed_dataset_infos = get_indexed_dataset_infos()

    # Load chunks.
    db_path = get_merged_db_path_map()[db_type]
    f = h5py.File(db_path, "r")
    chunks = f["chunks"]
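    # Note: the file handle is not closed here; 'chunks' is an h5py dataset
    # that DBDataset reads lazily, rather than a copy held in memory.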

    # DB dataset.
    indexed_datasets = [ info["dataset"] for info in indexed_dataset_infos ]
    dataset = DBDataset(db_path, indexed_datasets, chunks,
                        args.retro_gpt_chunk_length)

    return dataset


def get_merged_sampled_dataset(indexed_dataset_infos=None):
    return get_merged_dataset("sampled", indexed_dataset_infos)


def get_merged_train_dataset(indexed_dataset_infos=None):
    return get_merged_dataset("train", indexed_dataset_infos)


def get_merged_valid_dataset(indexed_dataset_infos=None):
    return get_merged_dataset("valid", indexed_dataset_infos)