"tests/modeling_tf_t5_test.py" did not exist on "b67fa1a8d2302d808ecb9d95355181eaf21ee3b6"
run_flax_glue.py 25.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
Suraj Patil's avatar
Suraj Patil committed
17
import json
18
19
20
import logging
import os
import random
21
import sys
22
import time
23
from dataclasses import dataclass, field
24
from itertools import chain
25
from pathlib import Path
26
from typing import Any, Callable, Dict, Optional, Tuple
27
28

import datasets
29
import numpy as np
30
from datasets import load_dataset, load_metric
31
from tqdm import tqdm
32
33
34
35
36
37
38
39

import jax
import jax.numpy as jnp
import optax
import transformers
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate
from flax.training import train_state
40
from flax.training.common_utils import get_metrics, onehot, shard
41
from huggingface_hub import Repository
Suraj Patil's avatar
Suraj Patil committed
42
43
44
45
from transformers import (
    AutoConfig,
    AutoTokenizer,
    FlaxAutoModelForSequenceClassification,
46
    HfArgumentParser,
Suraj Patil's avatar
Suraj Patil committed
47
    PretrainedConfig,
48
    TrainingArguments,
Suraj Patil's avatar
Suraj Patil committed
49
50
    is_tensorboard_available,
)
51
from transformers.utils import check_min_version, get_full_repo_name
52
53
54


# Module-level logger; its level is configured per process inside `main`.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.18.0.dev0")

# Loose type aliases used only for annotations in this script.
Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any


# Maps each GLUE task name to the (first, second) text column names of its
# dataset. The second entry is None for single-sentence tasks.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}


76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Only `model_name_or_path` is required; every other field has a default.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_slow_tokenizer: Optional[bool] = field(
        default=False,
        metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
110
111
112
113
114
115
116
117
118
119


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    At least one of `task_name`, `train_file` or `validation_file` must be given;
    `__post_init__` enforces this and checks the file extensions.
    """

    task_name: Optional[str] = field(
        default=None, metadata={"help": f"The name of the glue task to train on. choices {list(task_to_keys.keys())}"}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
    )
    text_column_name: Optional[str] = field(
        default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
    )
    label_column_name: Optional[str] = field(
        default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    # Annotated Optional because the default is None.
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization. If set, sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )

    def __post_init__(self):
        """Validate the data source arguments and lower-case `task_name`.

        Raises:
            ValueError: if no task name and no train/validation file is given,
                or if a given train/validation file is not a csv or json file.
        """
        if self.task_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Explicit raises instead of `assert`, so the checks survive `python -O`.
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            if extension not in ("csv", "json"):
                raise ValueError("`train_file` should be a csv or a json file.")
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            if extension not in ("csv", "json"):
                raise ValueError("`validation_file` should be a csv or a json file.")

        # Task names are matched against the lower-case keys of `task_to_keys`.
        if isinstance(self.task_name, str):
            self.task_name = self.task_name.lower()
188
189
190
191
192
193
194


def create_train_state(
    model: FlaxAutoModelForSequenceClassification,
    learning_rate_fn: Callable[[int], float],
    is_regression: bool,
    num_labels: int,
    weight_decay: float,
) -> train_state.TrainState:
    """Create initial training state.

    Args:
        model: Model whose `__call__` becomes the state's `apply_fn` and whose
            `params` become the initial parameters.
        learning_rate_fn: Schedule passed to the AdamW optimizer as its learning rate.
        is_regression: Selects the regression (MSE) or classification
            (softmax cross-entropy) loss and the matching `logits_fn`.
        num_labels: Number of classes used to one-hot encode labels
            (only used for classification).
        weight_decay: Weight decay passed to AdamW.

    Returns:
        A `TrainState` carrying two extra non-pytree fields, `logits_fn` and `loss_fn`.
    """

    class TrainState(train_state.TrainState):
        """Train state with an Optax optimizer.

        The two functions below differ depending on whether the task is classification
        or regression.

        Args:
          logits_fn: Applied to last layer to obtain the logits.
          loss_fn: Function to compute the loss.
        """

        logits_fn: Callable = struct.field(pytree_node=False)
        loss_fn: Callable = struct.field(pytree_node=False)

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    tx = optax.adamw(
        learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn
    )

    if is_regression:

        def mse_loss(logits, labels):
            # Only the first entry of the last axis is used as the regression output.
            return jnp.mean((logits[..., 0] - labels) ** 2)

        return TrainState.create(
            apply_fn=model.__call__,
            params=model.params,
            tx=tx,
            logits_fn=lambda logits: logits[..., 0],
            loss_fn=mse_loss,
        )
    else:  # Classification.

        def cross_entropy_loss(logits, labels):
            xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
            return jnp.mean(xentropy)

        return TrainState.create(
            apply_fn=model.__call__,
            params=model.params,
            tx=tx,
            # Predictions are the argmax over the last axis.
            logits_fn=lambda logits: logits.argmax(-1),
            loss_fn=cross_entropy_loss,
        )


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Build the learning rate schedule: linear warmup followed by linear decay.

    The rate rises from 0.0 to `learning_rate` over `num_warmup_steps`, then
    falls to 0 over the remaining training steps. The total step count is
    `(train_ds_size // train_batch_size) * num_train_epochs`.
    """
    total_steps = (train_ds_size // train_batch_size) * num_train_epochs
    remaining_steps = total_steps - num_warmup_steps

    warmup = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=remaining_steps)

    # Switch from warmup to decay once the warmup steps are done.
    return optax.join_schedules(schedules=[warmup, decay], boundaries=[num_warmup_steps])


def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
    """Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices.

    Args:
        rng: Random key used to draw the permutation of example indices.
        dataset: Training dataset; it is indexed with an array of indices and
            is expected to return a dict of columns.
        batch_size: Number of examples per yielded batch (before sharding).

    Yields:
        One dict of sharded numpy arrays per full batch. Examples that do not
        fill a complete batch are dropped.
    """
    steps_per_epoch = len(dataset) // batch_size
    perms = jax.random.permutation(rng, len(dataset))
    perms = perms[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    perms = perms.reshape((steps_per_epoch, batch_size))

    for perm in perms:
        batch = dataset[perm]
        batch = {k: np.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch


def glue_eval_data_collator(dataset: Dataset, batch_size: int):
    """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices.

    Batches are taken in dataset order. Only full batches are yielded: the
    trailing `len(dataset) % batch_size` examples are not produced here and
    must be handled by the caller.
    """
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        batch = {k: np.array(v) for k, v in batch.items()}
        batch = shard(batch)

        yield batch


def main():
    """Fine-tune a Flax sequence classification model on a GLUE task or on local csv/json files.

    Steps: parse the arguments, load the data, build the model and tokenizer,
    tokenize the datasets, then run the training loop. The loop periodically
    logs metrics, evaluates, and saves checkpoints. The last computed
    evaluation metrics are written to `eval_results.json` in the output directory.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).

    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.

    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)

    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", data_args.task_name)
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # The loader type is taken from the extension of whichever file was given.
        # (The attribute is `validation_file`; `valid_file` does not exist.)
        reference_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = reference_file.split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path, num_labels=num_labels, finetuning_task=data_args.task_name
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, use_fast=not model_args.use_slow_tokenizer
    )
    model = FlaxAutoModelForSequenceClassification.from_pretrained(model_args.model_name_or_path, config=config)

    # Preprocessing the datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            # One single message string: extra positional arguments would be
            # treated as %-format arguments by `logging` and break formatting.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
                result["labels"] = examples["label"]
        return result

    processed_datasets = raw_datasets.map(
        preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
    )

    train_dataset = processed_datasets["train"]
    eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]

    # Log a few random samples from the training set:
    for index in random.sample(range(len(train_dataset)), 3):
        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Define a summary writer
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(training_args.output_dir)
            summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    def write_train_metric(summary_writer, train_metrics, train_time, step):
        summary_writer.scalar("train_time", train_time, step)

        train_metrics = get_metrics(train_metrics)
        for key, vals in train_metrics.items():
            tag = f"train_{key}"
            # Values were accumulated over the last len(vals) steps ending at `step`.
            for i, val in enumerate(vals):
                summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    def write_eval_metric(summary_writer, eval_metrics, step):
        for metric_name, value in eval_metrics.items():
            summary_writer.scalar(f"eval_{metric_name}", value, step)

    num_epochs = int(training_args.num_train_epochs)
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
    eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()

    learning_rate_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    state = create_train_state(
        model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay
    )

    # define step functions
    def train_step(
        state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
    ) -> Tuple[train_state.TrainState, Dict[str, Array], PRNGKey]:
        """Trains model with an optimizer (both in `state`) on `batch`.

        Returns a triple `(new_state, metrics, new_dropout_rng)`, where
        `metrics` holds the loss and learning rate averaged over the "batch" axis.
        """
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        targets = batch.pop("labels")

        def loss_fn(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = state.loss_fn(logits, targets)
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
        return new_state, metrics, new_dropout_rng

    p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

    def eval_step(state, batch):
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        return state.logits_fn(logits)

    p_eval_step = jax.pmap(eval_step, axis_name="batch")

    if data_args.task_name is not None:
        metric = load_metric("glue", data_args.task_name)
    else:
        metric = load_metric("accuracy")

    logger.info(f"===== Starting training ({num_epochs} epochs) =====")
    train_time = 0

    # make sure weights are replicated on each device
    state = replicate(state)

    steps_per_epoch = len(train_dataset) // train_batch_size
    total_steps = steps_per_epoch * num_epochs
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
    for epoch in epochs:

        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # train
        train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size)
        for step, batch in enumerate(
            tqdm(
                train_loader,
                total=steps_per_epoch,
                desc="Training...",
                position=1,
            ),
        ):
            state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = (epoch * steps_per_epoch) + (step + 1)

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
                )

                train_metrics = []

            # `eval_steps` may be unset (None) or 0; in that case only evaluate
            # at the end of each epoch instead of failing on the modulo.
            eval_steps = training_args.eval_steps
            eval_due = bool(eval_steps) and cur_step % eval_steps == 0
            if (eval_due or cur_step % steps_per_epoch == 0) and cur_step > 0:

                # evaluate
                eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
                for batch in tqdm(
                    eval_loader,
                    total=len(eval_dataset) // eval_batch_size,
                    desc="Evaluating ...",
                    position=2,
                ):
                    labels = batch.pop("labels")
                    predictions = p_eval_step(state, batch)
                    metric.add_batch(predictions=chain(*predictions), references=chain(*labels))

                # evaluate also on leftover examples (not divisible by batch_size)
                num_leftover_samples = len(eval_dataset) % eval_batch_size

                # make sure leftover batch is evaluated on one device
                if num_leftover_samples > 0 and jax.process_index() == 0:
                    # take leftover samples
                    batch = eval_dataset[-num_leftover_samples:]
                    batch = {k: np.array(v) for k, v in batch.items()}

                    labels = batch.pop("labels")
                    predictions = eval_step(unreplicate(state), batch)
                    metric.add_batch(predictions=predictions, references=labels)

                eval_metric = metric.compute()

                logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")

                if has_tensorboard and jax.process_index() == 0:
                    # Log the metrics that were just computed.
                    write_eval_metric(summary_writer, eval_metric, cur_step)

            if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
            epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

    # save the eval metrics in json
    if jax.process_index() == 0:
        eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
        path = os.path.join(training_args.output_dir, "eval_results.json")
        with open(path, "w") as f:
            json.dump(eval_metric, f, indent=4, sort_keys=True)

628
629
630

# Run the fine-tuning entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()