# builder.py: builds train/valid/test datasets (only BERT is implemented) from indexed corpora.
from colossalai.logging import get_dist_logger

from .bert_dataset import BertDataset
4
5
6
from .blendable_dataset import BlendableDataset
from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_

# String identifiers for the dataset flavours accepted by the builders in this module.
DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5 = "standard_bert", "ict", "t5"

# Every recognized dataset type (only DSET_TYPE_BERT is currently implemented by the builders).
DSET_TYPES = list((DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5))


def _build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
    """Build the train/validation/test datasets for a single indexed corpus.

    The documents of the indexed dataset at ``data_prefix`` are divided into
    three contiguous document ranges computed by
    ``get_train_valid_test_split_(splits_string, num_documents)``, and a
    ``BertDataset`` is built over each non-empty range.

    Args:
        data_prefix: Prefix of the indexed dataset; forwarded to
            ``get_indexed_dataset_`` and ``BertDataset``.
        data_impl: Forwarded to ``get_indexed_dataset_``.
        splits_string: Split specification forwarded to
            ``get_train_valid_test_split_``.
        train_valid_test_num_samples: Indexable of three values; entry ``i``
            is passed as ``max_num_samples`` for split ``i``
            (0 = train, 1 = valid, 2 = test).
        max_seq_length: Forwarded to ``BertDataset``.
        masked_lm_prob: Forwarded to ``BertDataset``.
        short_seq_prob: Forwarded to ``BertDataset``.
        seed: Forwarded to ``BertDataset``.
        skip_warmup: Forwarded to ``get_indexed_dataset_``.
        binary_head: Forwarded to ``BertDataset``.
        dataset_type: One of ``DSET_TYPES``; only ``DSET_TYPE_BERT`` is
            implemented.

    Returns:
        Tuple ``(train_dataset, valid_dataset, test_dataset)``. An entry is
        ``None`` when its split contains no documents.

    Raises:
        ValueError: If ``dataset_type`` is not in ``DSET_TYPES``.
        NotImplementedError: If ``dataset_type`` is recognized but is not
            ``DSET_TYPE_BERT``.
    """
    if dataset_type not in DSET_TYPES:
        raise ValueError(f"Invalid dataset_type: {dataset_type}")
    # Fail fast, before loading the (potentially large) indexed dataset.
    if dataset_type != DSET_TYPE_BERT:
        raise NotImplementedError("Only BERT dataset is supported")

    indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup)

    # doc_idx is designed to hold num-docs + 1 boundaries so consecutive
    # entries delimit a document; hence the document count is length - 1.
    total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)

    logger = get_dist_logger()

    # Print stats about the splits.
    logger.info("\n > dataset split:", ranks=[0])

    def print_split_stats(name, index):
        """Log the document range and matching sentence range of split ``index``."""
        start_index = indexed_dataset.doc_idx[splits[index]]
        end_index = indexed_dataset.doc_idx[splits[index + 1]]
        logger.info(
            "\n    {}:".format(name)
            + "\n     document indices in [{}, {}) total of {} documents".format(
                splits[index], splits[index + 1], splits[index + 1] - splits[index]
            )
            + "\n     sentence indices in [{}, {}) total of {} sentences".format(
                start_index, end_index, end_index - start_index
            ),
            ranks=[0],
        )

    print_split_stats("train", 0)
    print_split_stats("validation", 1)
    print_split_stats("test", 2)

    def build_dataset(index, name):
        """Build a BertDataset over split ``index``, or return None if it is empty."""
        if splits[index + 1] <= splits[index]:
            return None

        # Keep the full doc-idx so it can be restored after the temporary view.
        doc_idx_ptr = indexed_dataset.get_doc_idx()
        start_index = splits[index]
        # +1 keeps the closing boundary of the last document in the split.
        end_index = splits[index + 1] + 1
        indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
        try:
            dataset = BertDataset(
                name=name,
                indexed_dataset=indexed_dataset,
                data_prefix=data_prefix,
                num_epochs=None,
                max_num_samples=train_valid_test_num_samples[index],
                max_seq_length=max_seq_length,
                masked_lm_prob=masked_lm_prob,
                short_seq_prob=short_seq_prob,
                seed=seed,
                binary_head=binary_head,
            )
        finally:
            # Always restore the full doc-idx, even if construction fails, so
            # the indexed dataset is never left pointing at a partial view.
            indexed_dataset.set_doc_idx(doc_idx_ptr)

        # Sanity checks that the restore brought back the original doc-idx.
        assert indexed_dataset.doc_idx[0] == 0
        assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
        return dataset

    train_dataset = build_dataset(0, "train")
    valid_dataset = build_dataset(1, "valid")
    test_dataset = build_dataset(2, "test")

    return (train_dataset, valid_dataset, test_dataset)


def build_train_valid_test_datasets(
    data_prefix,
    data_impl,
    splits_string,
    train_valid_test_num_samples,
    max_seq_length,
    masked_lm_prob,
    short_seq_prob,
    seed,
    skip_warmup,
    binary_head,
    dataset_type="standard_bert",
):
    """Build train/valid/test datasets from one corpus or a blend of corpora.

    With a single-element ``data_prefix`` this delegates directly to
    ``_build_train_valid_test_datasets``. Otherwise ``data_prefix`` is parsed
    by ``get_datasets_weights_and_num_samples`` into per-corpus prefixes,
    blending weights and per-corpus sample counts. Datasets are then built
    for every corpus, and each split is wrapped in a ``BlendableDataset``.

    Args:
        data_prefix: Sequence describing the corpora. A single element is
            used as the prefix directly; longer sequences are parsed by
            ``get_datasets_weights_and_num_samples``.
        data_impl: Forwarded to ``_build_train_valid_test_datasets``.
        splits_string: Forwarded to ``_build_train_valid_test_datasets``.
        train_valid_test_num_samples: Requested sample counts for the three
            splits; split across corpora when blending.
        max_seq_length: Forwarded to ``_build_train_valid_test_datasets``.
        masked_lm_prob: Forwarded to ``_build_train_valid_test_datasets``.
        short_seq_prob: Forwarded to ``_build_train_valid_test_datasets``.
        seed: Forwarded to ``_build_train_valid_test_datasets``.
        skip_warmup: Forwarded to ``_build_train_valid_test_datasets``.
        binary_head: Forwarded to ``_build_train_valid_test_datasets``.
        dataset_type: Forwarded to ``_build_train_valid_test_datasets``.

    Returns:
        Tuple ``(train, valid, test)``. An entry is ``None`` when no corpus
        produced a dataset for that split.
    """
    if len(data_prefix) == 1:
        return _build_train_valid_test_datasets(
            data_prefix[0],
            data_impl,
            splits_string,
            train_valid_test_num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )

    # Blending dataset: parse prefixes, weights and per-corpus sample counts.
    prefixes, weights, datasets_train_valid_test_num_samples = get_datasets_weights_and_num_samples(
        data_prefix, train_valid_test_num_samples
    )

    # One list of (dataset, weight) pairs per split. Keeping each weight next
    # to its own dataset keeps them aligned when a corpus yields no dataset
    # for some split. Passing the full ``weights`` list alongside a filtered
    # dataset list would mismatch lengths or pair datasets with wrong weights.
    split_pairs = ([], [], [])
    for prefix, weight, num_samples in zip(prefixes, weights, datasets_train_valid_test_num_samples):
        corpus_datasets = _build_train_valid_test_datasets(
            prefix,
            data_impl,
            splits_string,
            num_samples,
            max_seq_length,
            masked_lm_prob,
            short_seq_prob,
            seed,
            skip_warmup,
            binary_head,
            dataset_type=dataset_type,
        )
        for pairs, dataset in zip(split_pairs, corpus_datasets):
            # Truthiness skips None (empty split) and, if the dataset defines
            # __len__, zero-length datasets.
            if dataset:
                pairs.append((dataset, weight))

    train_pairs, valid_pairs, test_pairs = split_pairs
    return (
        _blend_or_none(train_pairs),
        _blend_or_none(valid_pairs),
        _blend_or_none(test_pairs),
    )


def _blend_or_none(pairs):
    """Wrap ``(dataset, weight)`` pairs in a ``BlendableDataset``, or return None if empty.

    NOTE(review): when some corpora are missing from a split, the remaining
    weights no longer sum to the original total. This assumes
    ``BlendableDataset`` normalizes its weights; confirm against
    ``.blendable_dataset``.
    """
    if not pairs:
        return None
    datasets = [dataset for dataset, _ in pairs]
    blend_weights = [weight for _, weight in pairs]
    return BlendableDataset(datasets, blend_weights)