run_clm_flax.py 37.2 KB
Newer Older
Suraj Patil's avatar
Suraj Patil committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script to your own causal language modeling task. Pointers for this are left as comments.

Suraj Patil's avatar
Suraj Patil committed
24
import json
Suraj Patil's avatar
Suraj Patil committed
25
26
27
28
29
import logging
import math
import os
import sys
import time
30
import warnings
31
32
from dataclasses import asdict, dataclass, field
from enum import Enum
33
from itertools import chain
Suraj Patil's avatar
Suraj Patil committed
34
35
36
37
38
39
from pathlib import Path
from typing import Callable, Optional

import datasets
import jax
import jax.numpy as jnp
40
import numpy as np
Suraj Patil's avatar
Suraj Patil committed
41
import optax
42
from datasets import Dataset, load_dataset
Suraj Patil's avatar
Suraj Patil committed
43
from flax import jax_utils, traverse_util
44
from flax.jax_utils import pad_shard_unpad, unreplicate
Suraj Patil's avatar
Suraj Patil committed
45
46
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
47
from huggingface_hub import Repository, create_repo
48
49
50
from tqdm import tqdm

import transformers
Suraj Patil's avatar
Suraj Patil committed
51
52
53
54
55
56
57
58
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForCausalLM,
    HfArgumentParser,
    is_tensorboard_available,
59
    set_seed,
Suraj Patil's avatar
Suraj Patil committed
60
61
)
from transformers.testing_utils import CaptureLogger
62
from transformers.utils import send_example_telemetry
Suraj Patil's avatar
Suraj Patil committed
63
64
65
66
67
68
69
70


logger = logging.getLogger(__name__)

# Config classes of every architecture that has a Flax causal-LM head, and the matching `model_type` strings
# (the latter are listed in the `--model_type` help text of ModelArguments).
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)


71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
@dataclass
class TrainingArguments:
    """
    Training hyper-parameters plus output-directory and Model Hub options used by this script.

    Every field whose name ends in ``_token`` is masked by :meth:`to_dict`, and ``repr``/``str`` are built from
    that masked view so secrets such as ``hub_token`` are not written to logs.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    # Defaults to None, so the annotation must allow None.
    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})

    def __post_init__(self):
        # Expand a leading `~` so paths like `~/runs/clm` point into the user's home directory.
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance while replacing `Enum` members by their values (for JSON serialization support).
        It obfuscates token values: every field whose name ends in `_token` is replaced by the placeholder
        `<FIELD_NAME>`, whether or not a token was set.

        Returns:
            dict mapping field name to its (possibly converted or masked) value.
        """
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d

    def __repr__(self):
        # The dataclass-generated repr would print `hub_token` in clear text, and main() logs this object.
        # Render the masked `to_dict()` view instead (dataclass keeps a __repr__ defined in the class body).
        rendered = ", ".join(f"{name}={value!r}" for name, value in self.to_dict().items())
        return f"{self.__class__.__name__}({rendered})"


Suraj Patil's avatar
Suraj Patil committed
133
134
135
136
137
138
139
140
141
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized and trained. Choose one of"
                " `[float32, float16, bfloat16]`."
            )
        },
    )
    # Defaults to None, so the annotation must allow None.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`; main() warns and copies its value into `token`.
    # NOTE(review): intentionally left as `bool` rather than `Optional[bool]`: HfArgumentParser appears to accept a
    # value-less `--use_auth_token` flag only for a plain `bool` annotation. Confirm before changing.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                # The trailing space is required: the adjacent literals are concatenated.
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
Suraj Patil's avatar
Suraj Patil committed
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Either `dataset_name` or at least one of `train_file` / `validation_file` must be given. Local files must end in
    `.csv`, `.json` or `.txt`; anything else raises `ValueError` at construction time.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    # Declared exactly once (a second declaration would silently override this one).
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )

    def __post_init__(self):
        # Validate that some data source was given and that local files have a supported extension.
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            if extension not in ["csv", "json", "txt"]:
                raise ValueError("`train_file` should be a csv, json or text file.")
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            if extension not in ["csv", "json", "txt"]:
                raise ValueError("`validation_file` should be a csv, json or text file.")
Suraj Patil's avatar
Suraj Patil committed
277
278
279
280
281
282
283
284
285


class TrainState(train_state.TrainState):
    """Flax ``TrainState`` that additionally carries the PRNG key used for dropout."""

    # PRNG key for dropout; `replicate` passes it through `shard_prng_key` when replicating the state.
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Replicate the state with `jax_utils.replicate` and swap in the sharded dropout key(s).

        NOTE(review): `shard_prng_key` presumably yields one key per local device — confirm against flax docs.
        """
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)


286
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
    """
    Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be
    incomplete, and range in size from 1 to `batch_size`. The order of the samples is shuffled if `shuffle` is `True`.

    Args:
        rng: PRNG key, only used to permute the sample order when `shuffle` is `True`.
        dataset: indexable dataset; `dataset[indices]` must return a mapping of column name -> values.
        batch_size: number of samples per batch.
        shuffle: whether to permute the sample order before batching.
        drop_last: if `True`, samples that do not fill a complete final batch are skipped.

    Yields:
        dict mapping each column name to a NumPy array holding one batch. An empty dataset yields nothing.
    """
    num_samples = len(dataset)
    if shuffle:
        batch_idx = np.asarray(jax.random.permutation(rng, num_samples))
    else:
        batch_idx = np.arange(num_samples)

    if drop_last:
        steps_per_epoch = num_samples // batch_size
        batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
        batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    else:
        steps_per_epoch = math.ceil(num_samples / batch_size)
        if steps_per_epoch == 0:
            # np.array_split raises ValueError for 0 sections; an empty dataset simply produces no batches.
            return
        batch_idx = np.array_split(batch_idx, steps_per_epoch)

    for idx in batch_idx:
        batch = dataset[idx]
        yield {k: np.array(v) for k, v in batch.items()}


312
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Log the training wall time and all accumulated training metrics to `summary_writer`.

    `train_metrics` is collapsed with `get_metrics` (presumably a list of per-step metric dicts — confirm at the
    call site). Each metric's values are written under `train_<name>` at consecutive steps, the last one at `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for metric_name, history in stacked_metrics.items():
        first_step = step - len(history) + 1
        for offset, value in enumerate(history):
            summary_writer.scalar(f"train_{metric_name}", value, first_step + offset)

321
322

def write_eval_metric(summary_writer, eval_metrics, step):
    """Write every evaluation metric to `summary_writer` as a scalar named `eval_<metric>` at `step`."""
    tagged_values = ((f"eval_{name}", value) for name, value in eval_metrics.items())
    for tag, value in tagged_values:
        summary_writer.scalar(tag, value, step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """
    Build a linear-warmup / linear-decay learning-rate schedule.

    The rate rises linearly from 0 to `learning_rate` over `num_warmup_steps`, then decays linearly to 0 over the
    remaining steps. The total step count is `(train_ds_size // train_batch_size) * num_train_epochs`.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    decay_steps = total_steps - num_warmup_steps
    phases = [
        optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps),
        optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=decay_steps),
    ]
    # The decay phase takes over once the step counter reaches `num_warmup_steps`.
    return optax.join_schedules(schedules=phases, boundaries=[num_warmup_steps])


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

354
    if model_args.use_auth_token is not None:
355
356
357
358
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
359
360
361
362
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

363
364
365
366
    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_clm", model_args, data_args, framework="flax")

Suraj Patil's avatar
Suraj Patil committed
367
368
369
370
371
372
373
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
374
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
Suraj Patil's avatar
Suraj Patil committed
375
376
377
378
379
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
380
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
Suraj Patil's avatar
Suraj Patil committed
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

396
397
398
    # Set seed before initializing model.
    set_seed(training_args.seed)

399
400
    # Handle the repository creation
    if training_args.push_to_hub:
401
402
403
404
405
406
407
408
        # Retrieve of infer repo_name
        repo_name = training_args.hub_model_id
        if repo_name is None:
            repo_name = Path(training_args.output_dir).absolute().name
        # Create repo and retrieve repo_id
        repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
        # Clone repo locally
        repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
409

Suraj Patil's avatar
Suraj Patil committed
410
411
412
413
414
415
416
417
418
419
420
421
    #  Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
422
423
424
425
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
426
            token=model_args.token,
427
            num_proc=data_args.preprocessing_num_workers,
Suraj Patil's avatar
Suraj Patil committed
428
429
430
431
432
433
434
435
        )

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
436
                token=model_args.token,
437
                num_proc=data_args.preprocessing_num_workers,
Suraj Patil's avatar
Suraj Patil committed
438
439
440
441
442
443
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
444
                token=model_args.token,
445
                num_proc=data_args.preprocessing_num_workers,
Suraj Patil's avatar
Suraj Patil committed
446
447
448
            )
    else:
        data_files = {}
449
        dataset_args = {}
Suraj Patil's avatar
Suraj Patil committed
450
451
452
453
454
455
456
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
457
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
458
459
460
461
462
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            **dataset_args,
463
            token=model_args.token,
464
            num_proc=data_args.preprocessing_num_workers,
465
        )
466

467
468
        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
469
470
471
472
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
473
                **dataset_args,
474
                token=model_args.token,
475
                num_proc=data_args.preprocessing_num_workers,
476
            )
477
            dataset["train"] = load_dataset(
478
479
480
481
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
482
                **dataset_args,
483
                token=model_args.token,
484
                num_proc=data_args.preprocessing_num_workers,
485
            )
Suraj Patil's avatar
Suraj Patil committed
486
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
487
    # https://huggingface.co/docs/datasets/loading_datasets.
Suraj Patil's avatar
Suraj Patil committed
488
489
490
491
492
493
494

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if model_args.config_name:
495
496
497
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
498
            token=model_args.token,
499
            trust_remote_code=model_args.trust_remote_code,
500
        )
Suraj Patil's avatar
Suraj Patil committed
501
    elif model_args.model_name_or_path:
502
503
504
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
505
            token=model_args.token,
506
            trust_remote_code=model_args.trust_remote_code,
507
        )
Suraj Patil's avatar
Suraj Patil committed
508
509
510
511
512
513
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
514
515
516
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
517
            token=model_args.token,
518
            trust_remote_code=model_args.trust_remote_code,
Suraj Patil's avatar
Suraj Patil committed
519
520
521
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
522
523
524
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
525
            token=model_args.token,
526
            trust_remote_code=model_args.trust_remote_code,
Suraj Patil's avatar
Suraj Patil committed
527
528
529
        )
    else:
        raise ValueError(
530
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
Suraj Patil's avatar
Suraj Patil committed
531
532
533
534
535
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForCausalLM.from_pretrained(
536
537
538
539
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
540
            token=model_args.token,
541
            trust_remote_code=model_args.trust_remote_code,
Suraj Patil's avatar
Suraj Patil committed
542
543
544
        )
    else:
        model = FlaxAutoModelForCausalLM.from_config(
545
546
547
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
548
            trust_remote_code=model_args.trust_remote_code,
Suraj Patil's avatar
Suraj Patil committed
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
        )

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
Sylvain Gugger's avatar
Sylvain Gugger committed
568
569
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
Suraj Patil's avatar
Suraj Patil committed
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
586
                f"Using block_size={min(1024, config.max_position_embeddings)} instead. You can change that default value by passing --block_size xxx."
Suraj Patil's avatar
Suraj Patil committed
587
            )
588
            block_size = min(1024, config.max_position_embeddings)
Suraj Patil's avatar
Suraj Patil committed
589
590
591
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
592
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
Suraj Patil's avatar
Suraj Patil committed
593
594
595
596
597
598
599
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
600
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
Suraj Patil's avatar
Suraj Patil committed
601
602
603
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
604
605
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
Suraj Patil's avatar
Suraj Patil committed
606
607
608
609
610
611
612
613
614
615
616
617
618
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
619
    # https://huggingface.co/docs/datasets/process#map
Suraj Patil's avatar
Suraj Patil committed
620
621
622
623
624
625
626
627
628
629
630
631
632

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
633
634
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
Suraj Patil's avatar
Suraj Patil committed
635
636
637
638
639
640

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
641
642
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
Suraj Patil's avatar
Suraj Patil committed
643
644

    # Enable tensorboard only on the master node
645
    has_tensorboard = is_tensorboard_available()
Suraj Patil's avatar
Suraj Patil committed
646
    if has_tensorboard and jax.process_index() == 0:
647
648
649
650
651
652
653
654
655
656
657
658
659
660
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
Suraj Patil's avatar
Suraj Patil committed
661
662
663
664
665
666
667
668

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
669
670
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()
Suraj Patil's avatar
Suraj Patil committed
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        """Build the weight-decay mask passed to ``optax.adamw(mask=...)``.

        Args:
            params: the (nested) model parameter pytree.

        Returns:
            A pytree with the same structure as ``params``. A leaf is ``True`` when that
            parameter should be decayed, and ``False`` for biases and for parameters
            identified as LayerNorm parameters.

        LayerNorm detection is heuristic. It looks for any parameter whose path components,
        concatenated and lower-cased, contain "layernorm", "layer_norm" or "ln". The last two
        components of each such path are recorded. Every parameter whose path ends in one of
        those recorded pairs is then excluded from decay.
        """
        flat_params = traverse_util.flatten_dict(params)
        norm_markers = ("layernorm", "layer_norm", "ln")

        # NOTE: path components are joined without a separator, so a marker can also match
        # across a component boundary (e.g. "model" + "norm" contains "ln"). Kept as-is so the
        # resulting mask is unchanged.
        norm_suffixes = set()
        for path in flat_params:
            joined_path = "".join(path).lower()
            if any(marker in joined_path for marker in norm_markers):
                norm_suffixes.add(path[-2:])

        def should_decay(path):
            # Biases and LayerNorm parameters are not decayed.
            return path[-1] != "bias" and path[-2:] not in norm_suffixes

        return traverse_util.unflatten_dict({path: should_decay(path) for path in flat_params})

    # create adam optimizer
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )
Suraj Patil's avatar
Suraj Patil committed
716
717

    # Setup train state
718
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
Suraj Patil's avatar
Suraj Patil committed
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772

    def loss_fn(logits, labels):
        """Mean next-token cross-entropy.

        The logit at position t is scored against the label at position t + 1. The last logit
        and the first label therefore have no partner and are dropped. The loss is averaged
        over every remaining position in the batch.
        """
        vocab_size = logits.shape[-1]
        predicted = logits[..., :-1, :]
        targets = onehot(labels[..., 1:], vocab_size)
        per_position = optax.softmax_cross_entropy(predicted, targets)
        return per_position.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        """One data-parallel optimisation step, intended to run under ``jax.pmap`` with axis "batch".

        Args:
            state: the replicated train state; it carries a ``dropout_rng`` key.
            batch: mapping of model inputs that includes a ``labels`` entry. The entry is
                popped from ``batch`` here.

        Returns:
            ``(new_state, metrics)``. ``new_state`` has the device-averaged gradients applied
            and a fresh dropout key. ``metrics`` holds the device-averaged ``loss`` and the
            ``learning_rate`` from the schedule at the pre-update step.
        """
        step_dropout_rng, next_dropout_rng = jax.random.split(state.dropout_rng)
        labels = batch.pop("labels")

        def compute_loss(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=step_dropout_rng, train=True)[0]
            return loss_fn(logits, labels)

        loss, grad = jax.value_and_grad(compute_loss)(state.params)
        # Average gradients across devices so every replica applies the same update.
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad, dropout_rng=next_dropout_rng)

        step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        return new_state, jax.lax.pmean(step_metrics, axis_name="batch")

    # Define eval fn
    def eval_step(params, batch):
        """Evaluation loss for one sharded batch, averaged across devices (axis "batch").

        ``labels`` is popped from ``batch``. The model runs with ``train=False``.
        """
        # NOTE(review): if pad_shard_unpad pads the final eval batch, padded rows may be
        # included in this mean — confirm whether the reported eval loss needs masking.
        labels = batch.pop("labels")
        outputs = model(**batch, params=params, train=False)
        return jax.lax.pmean({"loss": loss_fn(outputs[0], labels)}, axis_name="batch")

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
773
    train_metrics = []
774
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
Suraj Patil's avatar
Suraj Patil committed
775
776
777
778
779
780
781
782
783
784
785
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
786
        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
Suraj Patil's avatar
Suraj Patil committed
787
            batch = next(train_loader)
788
            batch = shard(batch)
Suraj Patil's avatar
Suraj Patil committed
789
790
791
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

792
            cur_step = epoch * (len(train_dataset) // train_batch_size) + step
Suraj Patil's avatar
Suraj Patil committed
793

794
            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
795
796
797
798
799
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
Suraj Patil's avatar
Suraj Patil committed
800

801
                epochs.write(
Sylvain Gugger's avatar
Sylvain Gugger committed
802
803
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
                    f" {train_metric['learning_rate'].mean()})"
804
805
806
                )

                train_metrics = []
Suraj Patil's avatar
Suraj Patil committed
807

808
809
810
            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = []
811
812
                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
                eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
813
814
815
                for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                    # Model forward
                    batch = next(eval_loader)
816
817
818
                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                        state.params, batch, min_device_batch=per_device_eval_batch_size
                    )
819
820
821
822
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
823
                eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
824
825
826
827
828
829
830

                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

                # Print metrics and update progress bar
Sylvain Gugger's avatar
Sylvain Gugger committed
831
832
833
834
                desc = (
                    f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:"
                    f" {eval_metrics['perplexity']})"
                )
835
836
                epochs.write(desc)
                epochs.desc = desc
Suraj Patil's avatar
Suraj Patil committed
837

838
839
840
841
842
843
844
845
                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
846
847
848
849
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
Suraj Patil's avatar
Suraj Patil committed
850

Suraj Patil's avatar
Suraj Patil committed
851
852
853
    # Eval after training
    if training_args.do_eval:
        eval_metrics = []
854
855
        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
        eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
Suraj Patil's avatar
Suraj Patil committed
856
857
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
            # Model forward
858
859
860
861
            batch = next(eval_loader)
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
Suraj Patil's avatar
Suraj Patil committed
862
863
864
865
            eval_metrics.append(metrics)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
866
        eval_metrics = jax.tree_util.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
Suraj Patil's avatar
Suraj Patil committed
867
868
869
870
871
872
873
874
875
876
877
878

        try:
            eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
        except OverflowError:
            eval_metrics["perplexity"] = float("inf")

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)

Suraj Patil's avatar
Suraj Patil committed
879
880
881

if __name__ == "__main__":
    # Run main() only when this file is executed as a script, not when it is imported.
    main()