utils.py 1.8 KB
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

import concurrent
import gc
import glob
import numpy as np
import os
import psutil
import time
import torch
from tqdm import tqdm

from megatron import get_retro_args, print_rank_0
from tools.retro.db.utils import get_indexed_dataset_infos
from tools.retro.external_libs import h5py


def get_index_dir():
    """Return the directory for the configured index, creating it if needed.

    The path is built from the Retro args as
    ``<retro_workdir>/index/<retro_index_type>/<retro_index_str>``.
    """

    retro_args = get_retro_args()

    # Components of the index directory, outermost first.
    path_parts = (
        retro_args.retro_workdir,
        "index",
        retro_args.retro_index_type,
        retro_args.retro_index_str,
    )
    index_dir = os.path.join(*path_parts)

    # exist_ok=True: an already-existing directory is not an error, so
    # repeated calls are harmless.
    os.makedirs(index_dir, exist_ok=True)

    return index_dir


def num_samples_to_block_ranges(num_samples):
    '''Partition the index range [0, num_samples) into consecutive blocks.

    Every block spans args.retro_block_size samples, except possibly the
    last one, which is clipped at num_samples. Returns a list of
    (start, end) tuples, where end is exclusive.'''
    step = get_retro_args().retro_block_size
    return [
        (begin, min(num_samples, begin + step))
        for begin in range(0, num_samples, step)
    ]


liangjing's avatar
v1  
liangjing committed
48
def get_training_data_root_dir():
    """Return the training-data root directory path.

    The path is ``<retro_workdir>/index/train_emb``. This function only
    builds the path; it does not create the directory.
    """
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "index", "train_emb")
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
51
52


liangjing's avatar
v1  
liangjing committed
53
54
def get_training_data_block_dir():
    """Return the ``blocks`` sub-directory path of the training-data root."""
    root_dir = get_training_data_root_dir()
    return os.path.join(root_dir, "blocks")
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
55
56


liangjing's avatar
v1  
liangjing committed
57
58
def get_training_data_block_paths():
    """Return the sorted paths of the ``*.hdf5`` files in the block directory.

    Only direct children of the training-data block directory are matched.
    The directory part of the pattern is escaped so that glob
    metacharacters (``*``, ``?``, ``[``) occurring in the configured path
    are matched literally instead of being interpreted as wildcards.
    """
    block_dir = glob.escape(get_training_data_block_dir())
    return sorted(glob.glob(block_dir + "/*.hdf5"))
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
59
60


liangjing's avatar
v1  
liangjing committed
61
62
63
64
def get_training_data_merged_path():
    """Return the path of the merged training-data ``.bin`` file.

    The file lives in the training-data root directory, and its name embeds
    args.retro_index_train_load_fraction formatted with three decimals
    (e.g. ``train_0.500.bin`` for a fraction of 0.5).
    """
    args = get_retro_args()
    root_dir = get_training_data_root_dir()
    file_name = "train_%.3f.bin" % args.retro_index_train_load_fraction
    return os.path.join(root_dir, file_name)
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
65
66


liangjing's avatar
v1  
liangjing committed
67
68
def get_added_codes_dir():
    """Return the ``add_codes`` sub-directory path under the index directory.

    Note: get_index_dir() creates the index directory itself; the
    ``add_codes`` directory is not created here.
    """
    index_dir = get_index_dir()
    return os.path.join(index_dir, "add_codes")
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
69
70


liangjing's avatar
v1  
liangjing committed
71
72
def get_added_code_paths():
    """Return the sorted paths of the ``*.hdf5`` files in the added-codes directory.

    Only direct children of the directory are matched. The directory part
    of the pattern is escaped so that glob metacharacters (``*``, ``?``,
    ``[``) occurring in the configured path are matched literally instead
    of being interpreted as wildcards.
    """
    codes_dir = glob.escape(get_added_codes_dir())
    return sorted(glob.glob(codes_dir + "/*.hdf5"))