# coding=utf-8
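"""Fine-tunes a Transformers token-classification model (e.g. for NER) with TensorFlow 2.

Example invocation (the data path, model identifier, and output path below are illustrative only):

    python run_tf_ner.py --data_dir /path/to/conll_data --model_name_or_path bert-base-cased --output_dir /path/to/output --max_seq_length 128 --do_train --do_eval --do_predict
"""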

import collections
import datetime
import glob
import math
import os
import re

import numpy as np
import tensorflow as tf
from absl import app, flags, logging
from fastprogress import master_bar, progress_bar
from seqeval import metrics

from transformers import (
    TF2_WEIGHTS_NAME,
    TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoTokenizer,
    GradientAccumulator,
    PreTrainedTokenizer,
    TFAutoModelForTokenClassification,
    create_optimizer,
)
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file


MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


flags.DEFINE_string(
    "data_dir", None, "The input data dir. Should contain the .conll files (or other data files) for the task."
)

flags.DEFINE_string(
    "model_name_or_path", None, "Path to pretrained model or model identifier from huggingface.co/models",
)

flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
)

flags.DEFINE_string("config_name", None, "Pretrained config name or path if not the same as model_name")

flags.DEFINE_string("tokenizer_name", None, "Pretrained tokenizer name or path if not the same as model_name")

flags.DEFINE_string("cache_dir", None, "Where to store the pre-trained models downloaded from s3")

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sentence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter "
    "will be padded.",
)

flags.DEFINE_string(
    "tpu",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.",
)

flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")

flags.DEFINE_boolean("do_train", False, "Whether to run training.")

flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")

flags.DEFINE_boolean(
    "evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
)

flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")

flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")

flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")

flags.DEFINE_integer(
    "gradient_accumulation_steps", 1, "Number of update steps to accumulate before performing a backward/update pass."
)

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("weight_decay", 0.0, "Weight decay if we apply some.")

flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for Adam optimizer.")

flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")

flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")

flags.DEFINE_integer(
    "max_steps", -1, "If > 0: set total number of training steps to perform. Overrides num_train_epochs."
)

flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")

flags.DEFINE_integer("logging_steps", 50, "Log every X update steps.")

flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X update steps.")

flags.DEFINE_boolean(
    "eval_all_checkpoints",
    False,
    "Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number",
)

flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA even if it is available")

flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")

flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")

flags.DEFINE_integer("seed", 42, "Random seed for initialization")

flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")

flags.DEFINE_string(
    "gpus",
    "0",
    "Comma separated list of GPU devices. If only one is given, a single-GPU "
    "strategy is used; if None, all available GPUs are used.",
)


def train(
    args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
):
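    """Custom distributed training loop with gradient accumulation, TensorBoard logging, and periodic eval/checkpointing.

    Per-device losses are computed in `step_fn`, accumulated with `GradientAccumulator`, and applied every
    `gradient_accumulation_steps` batches. Summaries are written to the hard-coded "/tmp/mylogs" directory.
    """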
    if args["max_steps"] > 0:
        num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
        args["num_train_epochs"] = 1
    else:
        num_train_steps = (
            math.ceil(num_train_examples / train_batch_size)
            // args["gradient_accumulation_steps"]
            * args["num_train_epochs"]
        )

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])

        if args["fp16"]:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")

        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args["num_train_epochs"])
    logging.info("  Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * args["gradient_accumulation_steps"],
    )
    logging.info("  Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()

    @tf.function
    def apply_gradients():
        grads_and_vars = []

        for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
        gradient_accumulator.reset()

    @tf.function
    def train_step(train_features, train_labels):
        def step_fn(train_features, train_labels):
            inputs = {"attention_mask": train_features["attention_mask"], "training": True}

            if "token_type_ids" in train_features:
                inputs["token_type_ids"] = train_features["token_type_ids"]

            with tf.GradientTape() as tape:
                logits = model(train_features["input_ids"], **inputs)[0]
                active_loss = tf.reshape(train_labels, (-1,)) != pad_token_label_id
                active_logits = tf.boolean_mask(tf.reshape(logits, (-1, len(labels))), active_loss)
                active_labels = tf.boolean_mask(tf.reshape(train_labels, (-1,)), active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)

        return mean_loss

    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args["num_train_epochs"]))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(
            train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
        )
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args["gradient_accumulation_steps"] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
                        # Log metrics
                        if (
                            args["n_device"] == 1 and args["evaluate_during_training"]
                        ):  # Only evaluate when single GPU otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(
                                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
                            )
                            report = metrics.classification_report(y_true, y_pred, digits=4)

                            logging.info("Eval at step " + str(global_step) + "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss, global_step)
                                tf.summary.scalar("precision", precision, global_step)
                                tf.summary.scalar("recall", recall, global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar(
                                "loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
                            )

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss", loss_metric.result(), step=step)

                    if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s", output_dir)

                train_iterator.child.comment = f"loss : {loss_metric.result()}"
                step += 1

        train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() - current_time))


def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
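    """Run evaluation on the `mode` split ("dev" or "test").

    Returns (y_true, y_pred, loss): per-sentence label sequences with pad positions removed, plus the mean loss.
    """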
    eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
    eval_dataset, size = load_and_cache_examples(
        args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
    )
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    preds = None
    num_eval_steps = math.ceil(size / eval_batch_size)
    master = master_bar(range(1))
    eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
    # from_logits=True so the eval loss is computed on raw logits, matching the training loss
    loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    loss = 0.0

    logging.info("***** Running evaluation *****")
    logging.info("  Num examples = %d", size)
    logging.info("  Batch size = %d", eval_batch_size)

    for eval_features, eval_labels in eval_iterator:
        inputs = {"attention_mask": eval_features["attention_mask"], "training": False}

        if "token_type_ids" in eval_features:
            inputs["token_type_ids"] = eval_features["token_type_ids"]

        with strategy.scope():
            logits = model(eval_features["input_ids"], **inputs)[0]
            active_loss = tf.reshape(eval_labels, (-1,)) != pad_token_label_id
            active_logits = tf.boolean_mask(tf.reshape(logits, (-1, len(labels))), active_loss)
            active_labels = tf.boolean_mask(tf.reshape(eval_labels, (-1,)), active_loss)
            cross_entropy = loss_fct(active_labels, active_logits)
            loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)

        if preds is None:
            preds = logits.numpy()
            label_ids = eval_labels.numpy()
        else:
            preds = np.append(preds, logits.numpy(), axis=0)
            label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)

    preds = np.argmax(preds, axis=2)
    y_pred = [[] for _ in range(label_ids.shape[0])]
    y_true = [[] for _ in range(label_ids.shape[0])]
    loss = loss / num_eval_steps

    for i in range(label_ids.shape[0]):
        for j in range(label_ids.shape[1]):
            if label_ids[i, j] != pad_token_label_id:
                y_pred[i].append(labels[preds[i, j] - 1])
                y_true[i].append(labels[label_ids[i, j] - 1])

    return y_true, y_pred, loss.numpy()


def load_cache(cached_file, tokenizer: PreTrainedTokenizer, max_seq_length):
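    """Parse a TFRecord cache file into a tf.data.Dataset of (features, label_ids) and return it with its size."""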
    name_to_features = {
        "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "attention_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    }
    # TODO Find a cleaner way to do this.
    if "token_type_ids" in tokenizer.model_input_names:
        name_to_features["token_type_ids"] = tf.io.FixedLenFeature([max_seq_length], tf.int64)

    def _decode_record(record):
        example = tf.io.parse_single_example(record, name_to_features)
        features = {}
        features["input_ids"] = example["input_ids"]
        features["attention_mask"] = example["attention_mask"]
        if "token_type_ids" in example:
            features["token_type_ids"] = example["token_type_ids"]

        return features, example["label_ids"]

    d = tf.data.TFRecordDataset(cached_file)
    d = d.map(_decode_record, num_parallel_calls=4)
    count = d.reduce(0, lambda x, _: x + 1)

    return d, count.numpy()


def save_cache(features, cached_features_file):
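    """Serialize the converted features to a TFRecord file so feature conversion only runs once per split."""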
    writer = tf.io.TFRecordWriter(cached_features_file)

    for (ex_index, feature) in enumerate(features):
        if ex_index % 5000 == 0:
            logging.info("Writing example %d of %d" % (ex_index, len(features)))

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        record_feature = collections.OrderedDict()
        record_feature["input_ids"] = create_int_feature(feature.input_ids)
        record_feature["attention_mask"] = create_int_feature(feature.attention_mask)
        if feature.token_type_ids is not None:
            record_feature["token_type_ids"] = create_int_feature(feature.token_type_ids)
        record_feature["label_ids"] = create_int_feature(feature.label_ids)

        tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))

        writer.write(tf_example.SerializeToString())

    writer.close()


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
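    """Return the tf.data.Dataset for `mode` and its size, creating the TFRecord cache on first use."""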
    drop_remainder = True if args["tpu"] or mode == "train" else False

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args["data_dir"],
        "cached_{}_{}_{}.tf_record".format(mode, tokenizer.__class__.__name__, str(args["max_seq_length"])),
    )
    if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
        logging.info("Loading features from cached file %s", cached_features_file)
        dataset, size = load_cache(cached_features_file, tokenizer, args["max_seq_length"])
    else:
        logging.info("Creating features from dataset file at %s", args["data_dir"])
        examples = read_examples_from_file(args["data_dir"], mode)
        features = convert_examples_to_features(
            examples,
            labels,
            args["max_seq_length"],
            tokenizer,
            cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=bool(args["model_type"] in ["roberta"]),
            # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=bool(args["model_type"] in ["xlnet"]),
            # pad on the left for xlnet
            pad_token=tokenizer.pad_token_id,
            pad_token_segment_id=tokenizer.pad_token_type_id,
            pad_token_label_id=pad_token_label_id,
        )
        logging.info("Saving features into cached file %s", cached_features_file)
        save_cache(features, cached_features_file)
        dataset, size = load_cache(cached_features_file, tokenizer, args["max_seq_length"])

    if mode == "train":
        dataset = dataset.repeat()
        dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])

    dataset = dataset.batch(batch_size, drop_remainder)
    dataset = dataset.prefetch(buffer_size=batch_size)

    return dataset, size


def main(_):
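    """Entry point: selects the distribution strategy from the flags, then runs the requested train/eval/predict phases."""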
    logging.set_verbosity(logging.INFO)
    args = flags.FLAGS.flag_values_dict()

    if (
        os.path.exists(args["output_dir"])
        and os.listdir(args["output_dir"])
        and args["do_train"]
        and not args["overwrite_output_dir"]
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args["output_dir"]
            )
        )

    if args["fp16"]:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    if args["tpu"]:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        args["n_device"] = args["num_tpu_cores"]
    elif len(args["gpus"].split(",")) > 1:
        args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
        strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
    elif args["no_cuda"]:
        args["n_device"] = 1
        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    else:
        args["n_device"] = len(args["gpus"].split(","))
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])

    logging.warning(
        "n_device: %s, distributed training: %s, 16-bits training: %s",
        args["n_device"],
        bool(args["n_device"] > 1),
        args["fp16"],
    )

    labels = get_labels(args["labels"])
    num_labels = len(labels)
    pad_token_label_id = -1
    config = AutoConfig.from_pretrained(
        args["config_name"] if args["config_name"] else args["model_name_or_path"],
        num_labels=num_labels,
        cache_dir=args["cache_dir"],
    )

    logging.info("Training/evaluation parameters %s", args)
    args["model_type"] = config.model_type

    # Training
    if args["do_train"]:
        tokenizer = AutoTokenizer.from_pretrained(
            args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
            do_lower_case=args["do_lower_case"],
            cache_dir=args["cache_dir"],
        )

        with strategy.scope():
            model = TFAutoModelForTokenClassification.from_pretrained(
                args["model_name_or_path"],
                from_pt=bool(".bin" in args["model_name_or_path"]),
                config=config,
                cache_dir=args["cache_dir"],
            )

        train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
        train_dataset, num_train_examples = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
        )
        train_dataset = strategy.experimental_distribute_dataset(train_dataset)
        train(
            args,
            strategy,
            train_dataset,
            tokenizer,
            model,
            num_train_examples,
            labels,
            train_batch_size,
            pad_token_label_id,
        )

        os.makedirs(args["output_dir"], exist_ok=True)

        logging.info("Saving model to %s", args["output_dir"])

        model.save_pretrained(args["output_dir"])
        tokenizer.save_pretrained(args["output_dir"])

    # Evaluation
    if args["do_eval"]:
        tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        checkpoints = []
        results = []

        if args["eval_all_checkpoints"]:
            checkpoints = list(
                os.path.dirname(c)
                for c in sorted(
                    glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
                    key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
                )
            )

        logging.info("Evaluate the following checkpoints: %s", checkpoints)

        if len(checkpoints) == 0:
            checkpoints.append(args["output_dir"])

        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"

            with strategy.scope():
                model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)

            y_true, y_pred, eval_loss = evaluate(
                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
            )
            report = metrics.classification_report(y_true, y_pred, digits=4)

            if global_step:
                results.append({global_step + "_report": report, global_step + "_loss": eval_loss})

        output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")

        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            for res in results:
                for key, val in res.items():
                    if "loss" in key:
                        logging.info(key + " = " + str(val))
                        writer.write(key + " = " + str(val))
                        writer.write("\n")
                    else:
                        # write the report stored for this checkpoint, not the last one computed in the loop above
                        logging.info(key)
                        logging.info("\n" + val)
                        writer.write(key + "\n")
                        writer.write(val)
                        writer.write("\n")

    if args["do_predict"]:
        tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
        eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
        predict_dataset, _ = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
        )
        y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
        output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
        output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
        report = metrics.classification_report(y_true, y_pred, digits=4)

        with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
            logging.info("\n" + report)

            writer.write(report)
            writer.write("\n\nloss = " + str(pred_loss))

        with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
            with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
                example_id = 0

                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)

                        if not y_pred[example_id]:
                            example_id += 1
                    elif y_pred[example_id]:
                        output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("output_dir")
    flags.mark_flag_as_required("model_name_or_path")
    app.run(main)