chunk_dataset.py
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

import os
import torch

from megatron import get_retro_args, print_rank_0
from megatron.data.gpt_dataset import build_train_valid_test_datasets \
    as build_gpt_train_valid_test_datasets
from megatron.training import (
    build_train_valid_test_datasets as build_pretraining_train_valid_test_datasets,
    update_train_iters,
)
from tools.retro.db.utils import get_indexed_dataset_infos
from tools.retro.utils import get_num_chunks_per_sample

from .utils import get_neighbor_dirname, get_query_workdir


class ChunkDataset(torch.utils.data.Dataset):
    '''Pretraining chunk dataset wraps a standard GPT dataset.

    This dataset conceptually divides each sample (e.g., length 2048)
    into chunks (e.g., length 64) and restructures them into a list of
    chunks (e.g., length num_samples * num_chunks_per_sample).
    '''

    def __init__(self, sample_dataset, chunk_length):

        super().__init__()

        self.sample_dataset = sample_dataset

        self.chunk_length = chunk_length
        self.n_chunks_per_sample = get_num_chunks_per_sample()
        self.n_samples = len(sample_dataset)
        self.n_chunks = self.n_samples * self.n_chunks_per_sample

    def __len__(self):
        return self.n_chunks

    def __getitem__(self, idx):

        # Convert global chunk index to global sample index & local chunk index.
        sample_idx = idx // self.n_chunks_per_sample
        chunk_idx = idx % self.n_chunks_per_sample

        # Extract sample data.
        sample = self.sample_dataset[sample_idx]
        sample_token_ids = sample["text"]
        sample_doc_ids = sample["doc_ids"]

        # Chunk start/end token idxs.
        token_start_idx = chunk_idx * self.chunk_length
        token_end_idx = token_start_idx + self.chunk_length
        chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx]

        # Return the chunk's tokens along with the parent sample's doc ids.
        return {
            "doc_ids" : sample_doc_ids,
            "text" : chunk_token_ids,
        }
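

# Illustrative sketch (not part of the pipeline): the sample/chunk index
# arithmetic used in ChunkDataset.__getitem__, with hypothetical values
# (seq_length=2048, chunk_length=64, i.e., 32 chunks per sample).
def _example_chunk_index_math(idx=70, chunk_length=64, seq_length=2048):
    '''Hypothetical helper; mirrors the arithmetic above for one index.'''
    n_chunks_per_sample = seq_length // chunk_length  # 2048 // 64 == 32
    sample_idx = idx // n_chunks_per_sample           # 70 // 32 == 2
    chunk_idx = idx % n_chunks_per_sample             # 70 %  32 == 6
    token_start_idx = chunk_idx * chunk_length        # 6 * 64 == 384
    token_end_idx = token_start_idx + chunk_length    # 384 + 64 == 448
    return sample_idx, (token_start_idx, token_end_idx)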


def verify_indexed_dataset_order():
    '''Verify that the pretraining dataset order matches the DB order.'''

    args = get_retro_args()

    # DB dataset prefixes.
    db_indexed_dataset_infos = get_indexed_dataset_infos()
    db_prefixes = [ info["prefix"] for info in db_indexed_dataset_infos ]

    # Verify order & prefixes.
    assert len(args.data_path) >= 2, "only blendable datasets are supported."
    pretraining_prefixes = args.data_path[1::2]

    if len(db_prefixes) != len(pretraining_prefixes):
        raise Exception("inconsistent dataset count between db & pretraining.")
    if db_prefixes != pretraining_prefixes:
        raise Exception("inconsistent dataset order between db & pretraining.")


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""

    args = get_retro_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
    train_ds, valid_ds, test_ds = build_gpt_train_valid_test_datasets(
        data_prefix=args.retro_gpt_data_path,
        data_impl=args.retro_gpt_data_impl,
        splits_string=args.retro_gpt_split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.retro_gpt_seq_length,
        seed=args.retro_gpt_seed,
        skip_warmup=(not args.retro_gpt_mmap_warmup),
        return_doc_ids=args.retro_return_doc_ids)
    print_rank_0("> finished creating pretrained GPT datasets ...")

    return train_ds, valid_ds, test_ds


def get_chunk_dataset_map():
    '''Get train, valid, test chunk datasets.'''

    args = get_retro_args()

    # Update train iters.
    update_train_iters(args)

    # Reset iteration state expected by the pretraining dataset builder.
    args.iteration = 0
    args.consumed_train_samples = 0

    # Verify indexed dataset order.
    verify_indexed_dataset_order()

    # Datasets.
    print_rank_0(" > datasets.")
    train_ds, valid_ds, test_ds = build_pretraining_train_valid_test_datasets(
        train_valid_test_datasets_provider)

    sample_dataset_map = {
        "train" : train_ds,
        "valid" : valid_ds,
        "test" : test_ds,
    }

    # Map each split to its chunk dataset & neighbor directory.
    chunk_dataset_map = {
        key : {
            "neighbor_dir" : get_neighbor_dirname(key, sample_ds),
            "data" : ChunkDataset(sample_ds, args.retro_gpt_chunk_length),
        }
        for key, sample_ds in sample_dataset_map.items() if sample_ds
    }

    return chunk_dataset_map
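

# Illustrative usage sketch (hypothetical; assumes Megatron/Retro args have
# been initialized before calling):
#
#   chunk_dataset_map = get_chunk_dataset_map()
#   for key, info in chunk_dataset_map.items():
#       print(key, len(info["data"]), info["neighbor_dir"])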