run_tf.py 4.03 KB
Newer Older
Philipp Schmid's avatar
Philipp Schmid committed
1
2
3
4
5
6
7
import argparse
import logging
import sys
import time

import tensorflow as tf
from datasets import load_dataset
8
from packaging.version import parse
Philipp Schmid's avatar
Philipp Schmid committed
9
10
11
12

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification


13
14
15
16
17
18
19
20
21
22
23
24
25
# Prefer the backwards-compatible `tf_keras` package. Only when it cannot be
# imported do we fall back to the plain `keras` package, and then only if that
# installation is not Keras 3 or newer.
try:
    import tf_keras as keras
except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
    import keras

    if parse(keras.__version__).major >= 3:
        raise ValueError(
            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
            "Transformers. Please install the backwards-compatible tf-keras package with "
            "`pip install tf-keras`."
        )


Philipp Schmid's avatar
Philipp Schmid committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def str_to_bool(value):
    """Convert a command-line flag value into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the usual textual spellings instead.

    Args:
        value: The raw value; a ``bool`` is returned unchanged, anything else
            is converted with ``str()`` and compared case-insensitively.

    Returns:
        ``True`` for "true"/"t"/"yes"/"y"/"1", ``False`` for
        "false"/"f"/"no"/"n"/"0".

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_tf_dataset(dataset, tokenizer, batch_size):
    """Tokenize a text-classification split and wrap it in a batched ``tf.data.Dataset``.

    Args:
        dataset: A ``datasets`` split exposing ``"text"`` and ``"label"`` columns.
        tokenizer: The tokenizer applied to the ``"text"`` column.
        batch_size: Batch size of the returned dataset.

    Returns:
        A ``tf.data.Dataset`` yielding ``(features, label)`` batches, where
        ``features`` is a dict with ``"input_ids"`` and ``"attention_mask"``.
    """
    dataset = dataset.map(lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True)
    dataset.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])

    # NOTE(review): `.to_tensor(...)` is called on each column, so the columns are
    # presumably returned as ragged tensors here -- confirm against the installed
    # `datasets` version.
    features = {
        name: dataset[name].to_tensor(default_value=0, shape=[None, tokenizer.model_max_length])
        for name in ["input_ids", "attention_mask"]
    }
    return tf.data.Dataset.from_tensor_slices((features, dataset["label"])).batch(batch_size)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=16)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8)
    parser.add_argument("--model_name_or_path", type=str)
    # Must be parsed as a float: a string learning rate cannot be used by the optimizer.
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    # `type=bool` would turn "False" into True, hence the explicit converter.
    # NOTE: --do_train, --do_eval and --output_dir are accepted but not consulted
    # by the code below; the script always trains and does not evaluate.
    parser.add_argument("--do_train", type=str_to_bool, default=True)
    parser.add_argument("--do_eval", type=str_to_bool, default=True)
    parser.add_argument("--output_dir", type=str)

    # Unknown arguments are tolerated rather than rejected.
    args, _ = parser.parse_known_args()

    # overwrite batch size until we have tf_glue.py
    args.per_device_train_batch_size = 16
    args.per_device_eval_batch_size = 16

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Load model and tokenizer
    model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    # Load dataset and downsample it (5k train / 500 test examples)
    train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])
    train_dataset = train_dataset.shuffle().select(range(5000))
    test_dataset = test_dataset.shuffle().select(range(500))

    # Preprocess both splits the same way
    tf_train_dataset = build_tf_dataset(train_dataset, tokenizer, args.per_device_train_batch_size)
    # Built for parity with the original script; not consumed by the code below.
    tf_test_dataset = build_tf_dataset(test_dataset, tokenizer, args.per_device_eval_batch_size)

    # Define optimizer, loss and metrics
    optimizer = keras.optimizers.Adam(learning_rate=args.learning_rate)
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metrics = [keras.metrics.SparseCategoricalAccuracy()]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # The dataset is already batched, so `batch_size` is not passed to `fit`.
    # perf_counter is used because it is meant for measuring elapsed time.
    start_train_time = time.perf_counter()
    train_results = model.fit(tf_train_dataset, epochs=args.epochs)
    train_runtime = time.perf_counter() - start_train_time

    logger.info("*** Train ***")
    logger.info("train_runtime = %s", train_runtime)
    for key, value in train_results.history.items():
        logger.info("  %s = %s", key, value)