"lm_eval/tasks/catalan_bench/paws_ca.yaml" did not exist on "ea17b98e3cac7e4e80f8ac2183d1837fcd90f5fc"
cli.py 11 KB
Newer Older
liangjing's avatar
liangjing committed
1
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.

import json
import numpy as np
import os
import typing as T
from types import SimpleNamespace

from megatron.training.arguments import load_retro_config, parse_args, validate_args
from megatron.core.datasets.retro.db.dataset import DBDataset
from megatron.core.datasets.retro.db.utils import (
    get_indexed_dataset_infos as get_db_indexed_dataset_infos,
    get_merged_train_dataset as get_db_dataset,
)
from megatron.core.datasets.retro.query.retro_dataset import get_retro_datasets, RetroDataset
from megatron.training.global_vars import set_global_variables
from megatron.training.training import build_train_valid_test_datasets, update_train_iters
from pretrain_retro import train_valid_test_datasets_provider
from tools.retro.preprocess_data import get_tokenizers
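
# Example usage (an illustrative sketch, not part of the module; the project
# path is hypothetical, and this assumes the class below is importable as
# `tools.retro.cli.retro`):
#
#   retro.init("/path/to/retro/project")   # must be called before anything else
#   retro.get_db_chunk_text(0)             # first chunk of the chunk DB, as text
#   retro.get_pt_sample("train", 0)        # first pretraining sample (dict)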


def shorten_str(s: str, n: int) -> str:
    s = "\\n".join(s.splitlines())
    return s if len(s) <= n else "%s ... %s" % (s[: n // 2], s[-n // 2 :])
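
# Example (a quick check of the helper above; newlines are first flattened to
# a literal "\n" before truncation):
#
#   shorten_str("abcdefghij", 6)  # -> 'abc ... hij'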


class retro:
    '''CLI namespace for inspecting a Retro project's chunk database and pretraining corpus.'''

    config = None

    ##############################################
    # initialize.
    ##############################################

    @classmethod
    def init(cls, project_dir: str) -> None:
        '''Initialize Megatron, tokenizers, and datasets.'''

        # Megatron args.
        args = parse_args(extra_args_provider=None, ignore_unknown_args=False)
        args.retro_project_dir = project_dir
        args.micro_batch_size = 1
        args.num_layers = 1
        args.hidden_size = 1
        args.num_attention_heads = 1
        args.async_tensor_model_parallel_allreduce = False
        args.retro_add_retriever = True  # for building RetroDataset
        validate_args(args)
        set_global_variables(args)
        update_train_iters(args)

        # Retro config.
        cls.config = load_retro_config(project_dir)
        cls.config.retro_project_dir = project_dir
        cls.config.retro_tokenizers = get_tokenizers(cls.config)

        # Chunk database dataset.
        cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos(project_dir)
        cls.db_dataset = get_db_dataset(project_dir,
                                        cls.config.retro_gpt_chunk_length,
                                        cls.config.retro_tokenizers.gpt.eod)

        # Pretraining datasets.
        pt_train_ds, pt_valid_ds, pt_test_ds = build_train_valid_test_datasets(
            train_valid_test_datasets_provider)
        cls.pt_datasets = SimpleNamespace(
            train=pt_train_ds,
            valid=pt_valid_ds,
            test=pt_test_ds,
        )

        # Print usage.
        cls.print_usage()

    ##############################################
    # utils.
    ##############################################

    @classmethod
    def gpt_to_text(cls, token_ids: np.ndarray) -> str:
        '''GPT tokens to text.'''
        return cls.config.retro_tokenizers.gpt.detokenize(
            token_ids.tolist() if isinstance(token_ids, np.ndarray) else token_ids
        )

    @classmethod
    def text_to_bert(cls, text: str) -> np.ndarray:
        '''Text to Bert tokens.'''
        return cls.config.retro_tokenizers.bert.tokenize(text)
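
    # Example round trip (a sketch; the resulting ids depend on the GPT and
    # Bert tokenizers configured for the project):
    #
    #   text = retro.get_db_chunk_text(0)    # detokenized via the GPT tokenizer
    #   bert_ids = retro.text_to_bert(text)  # np.ndarray of Bert token ids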

    ##############################################
    # chunk db.
    ##############################################

    @classmethod
    def get_db_num_indexed_datasets(cls) -> int:
        '''Number of indexed datasets within blended dataset.'''
        return len(cls.db_indexed_dataset_infos)

    @classmethod
    def get_db_indexed_dataset_infos(cls) -> T.List[T.Tuple[float, str]]:
        '''Dataset infos, including number of training & sampled sets.'''
        return [(info["ratio"], info["prefix"]) for info in cls.db_indexed_dataset_infos]
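
    # e.g., returns pairs like [(0.5, "dataset-a"), (0.5, "dataset-b")]
    # (the ratios and prefixes shown here are hypothetical).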

    @classmethod
    def get_db_dataset(cls) -> DBDataset:
        '''Get the merged chunk database dataset.'''
        return cls.db_dataset

    @classmethod
    def get_db_num_chunks(cls) -> int:
        '''Number of DB chunks.'''
        return len(cls.get_db_dataset())

    @classmethod
    def get_db_chunk_gpt(cls, idx: int) -> T.List[int]:
        '''Get DB chunk as GPT token ids.'''
        return cls.get_db_dataset()[idx]["text"].tolist()

    @classmethod
    def get_db_chunk_bert(cls, idx: int) -> T.List[int]:
        '''Get DB chunk as Bert token ids.'''
        return cls.text_to_bert(cls.get_db_chunk_text(idx))

    @classmethod
    def get_db_chunk_text(cls, idx: int) -> str:
        '''Get DB chunk as text.'''
        return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))

    @classmethod
    def get_db_chunk_and_continuation_text(cls, idx: int) -> T.List[str]:
        '''Get DB chunk along with continuation, as text.'''

        # Modulus used here to match the original implementation (i.e., the
        # last chunk's continuation wraps around to the first chunk).
        return [
            cls.get_db_chunk_text(idx),
            cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
        ]
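
    # e.g., for the final DB chunk the continuation wraps to chunk 0, per the
    # modulus above (a sketch):
    #
    #   last = retro.get_db_num_chunks() - 1
    #   pair = retro.get_db_chunk_and_continuation_text(last)
    #   assert pair[1] == retro.get_db_chunk_text(0)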

    ##############################################
    # pretraining corpus.
    ##############################################

    @classmethod
    def get_pt_num_samples_and_chunks(cls, data_key: str) -> T.Tuple[int, int]:
        '''Number of samples & chunks in the corpus (chunks = samples * chunks-per-sample; e.g., 32 * n_samples).'''
        assert hasattr(cls.pt_datasets, data_key), (
            "pretraining set '%s' not found (choices: %s)."
            % (data_key, ", ".join(vars(cls.pt_datasets).keys()))
        )
        chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
        return (
            len(chunk_dataset.sample_dataset),
            len(chunk_dataset),
        )
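
    # e.g. (hypothetical counts; actual values depend on the project's
    # training arguments):
    #
    #   n_samples, n_chunks = retro.get_pt_num_samples_and_chunks("train")
    #   chunks_per_sample = n_chunks // n_samples   # e.g., 32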

    @classmethod
    def get_pt_num_samples(cls, data_key: str) -> int:
        '''Number of pretraining samples.'''
        return cls.get_pt_num_samples_and_chunks(data_key)[0]

    @classmethod
    def get_pt_num_chunks(cls, data_key: str) -> int:
        '''Number of pretraining chunks (e.g., 32*n_samples).'''
        return cls.get_pt_num_samples_and_chunks(data_key)[1]

    @classmethod
    def get_pt_dataset(cls, data_key: str) -> RetroDataset:
        '''Get pretraining dataset ('train', 'valid', or 'test').'''
        return getattr(cls.pt_datasets, data_key)

    @classmethod
    def get_pt_sample(cls, data_key: str, idx: int) -> dict:
        '''Get a single pretraining sample.'''
        return getattr(cls.pt_datasets, data_key)[idx]

    @classmethod
    def get_neighbor_tokens(cls, sample_id: int, chunk_id: int, data_key: str = "train") -> T.Optional[dict]:
        '''Get GPT token ids for a pretraining chunk and its retrieved neighbors; None if the lookup fails.'''
        try:
            sample = cls.get_pt_sample(data_key, sample_id)
            sample_token_ids = sample["text"]
            chunk_length = cls.config.retro_gpt_chunk_length
            chunk_start_idx = chunk_id * chunk_length
            chunk_end_idx = min(sample_token_ids.shape[0], chunk_start_idx + chunk_length)
            chunk_token_ids = sample_token_ids[chunk_start_idx:chunk_end_idx]
            neighbor_token_ids = sample["neighbor_tokens"][chunk_id]
            return {
                "chunk_tokens": chunk_token_ids,
                "neighbor_tokens": neighbor_token_ids,
            }
        except Exception:
            return None
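
    # Example (a sketch): fetch GPT tokens for chunk 0 of training sample 0;
    # the result is None when the lookup fails (e.g., no stored neighbors):
    #
    #   nbrs = retro.get_neighbor_tokens(0, 0)
    #   if nbrs is not None:
    #       print(retro.gpt_to_text(nbrs["chunk_tokens"]))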

    @classmethod
    def print_neighbor_texts(cls, sample_id: int, chunk_id: int, data_key: str = "train") -> None:
        '''Print a pretraining chunk and its retrieved neighbor chunks as text.'''
        tokens: T.Optional[dict] = cls.get_neighbor_tokens(sample_id, chunk_id, data_key)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        try:
            print("PRETRAINING CHUNK:")
            print("  - %s" % shorten_str(cls.gpt_to_text(tokens["chunk_tokens"]), 150))
            print("NEIGHBOR_CHUNKS:")
            for token_ids in tokens["neighbor_tokens"]:
                print("  - %s" % shorten_str(cls.gpt_to_text(token_ids), 150))
        except Exception:
            print("<no neighbors for sample %d>" % sample_id)

    ##############################################
    # usage.
    ##############################################

    @classmethod
    def print_usage(cls) -> None:
        '''Print usage.'''

        print()
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")

        print()
        print("~~~~ indexed datasets ~~~~")
        print("retro.get_db_num_indexed_datasets() : %s" % cls.get_db_num_indexed_datasets())
        print("retro.get_db_indexed_dataset_infos() :")
        for i, (ratio, prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
            print(
                "  %s(%f, %s)%s"
                % (
                    "[" if i == 0 else " ",
                    ratio,
                    prefix,
                    "]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
                )
            )

        print()
        print("~~~~ counts ~~~~")
        print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())

        print()
        for sq_key in ("sample", "chunk"):
            for data_key in ("train", "valid"):  # test?
                print(
                    "retro.get_pt_num_%ss('%s') : %d."
                    % (sq_key, data_key, getattr(cls, f"get_pt_num_{sq_key}s")(data_key))
                )

        print()
        print("~~~~ tokens, text ~~~~")
        print(
            "retro.get_db_chunk_gpt(chunk_id) : %s"
            % shorten_str(str(retro.get_db_chunk_gpt(0)), 50)
        )
        print(
            "retro.get_db_chunk_bert(chunk_id) : %s"
            % shorten_str(str(retro.get_db_chunk_bert(0)), 50)
        )
        print(
            "retro.get_db_chunk_text(chunk_id) : %s"
            % shorten_str(retro.get_db_chunk_text(0).strip(), 50)
        )
        print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
        for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
            print(
                "  %s'%s'%s"
                % (
                    "[" if i == 0 else " ",
                    shorten_str(t.strip().replace("\n", " "), 50),
                    "]" if i == 1 else ",",
                )
            )

        sample = cls.get_pt_sample("train", 0)
        sample_chunk_id = sample["neighbor_tokens"].shape[0] // 2
        sample_neighbor_id = 0
        print()
        print("retro.get_pt_sample('train', sample_id) :")
        print("  {")
        for k, v in sample.items():
            print("    '%s' : %s" % (k, shorten_str(str(v), 50)))
        print("  }")

        print()
        print("(e.g., sample = retro.get_pt_sample(...))")
        print()
        print("  sample['text'].shape : %s" % str(sample["text"].shape))
        print("  sample['neighbor_tokens'].shape : %s" % str(sample["neighbor_tokens"].shape))
        print("  sample['text'] : %s" % shorten_str(str(sample["text"]), 50))
        print(
            "  sample['neighbor_tokens'][%d][%d] : %s"
            % (
                sample_chunk_id,
                sample_neighbor_id,
                shorten_str(str(sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]), 50),
            )
        )
        print(
            "  retro.gpt_to_text(sample['text']) : %s"
            % shorten_str(cls.gpt_to_text(sample["text"]), 50)
        )
        print(
            "  retro.gpt_to_text(sample['neighbor_tokens'][%d][%d]) : %s"
            % (
                sample_chunk_id,
                sample_neighbor_id,
                shorten_str(
                    cls.gpt_to_text(sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]),
                    50,
                ),
            )
        )

        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")