"src/vscode:/vscode.git/clone" did not exist on "fa5918ad132ac2d9552a21ce5623dd8c97324a66"
run_tf_ner.py 25.7 KB
Newer Older
1
# coding=utf-8
import collections
import datetime
import glob
import math
import os
import re

import numpy as np
import tensorflow as tf
from absl import app, flags, logging
from seqeval import metrics

from transformers import (
    TF2_WEIGHTS_NAME,
    TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
    AutoConfig,
    AutoTokenizer,
    GradientAccumulator,
    TFAutoModelForTokenClassification,
    create_optimizer,
)
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file


try:
    from fastprogress import master_bar, progress_bar
except ImportError:
    from fastprogress.fastprogress import master_bar, progress_bar


MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())


flags.DEFINE_string(
    "data_dir", None, "The input data dir. Should contain the .conll files (or other data files) for the task."
)

flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_TYPES))

flags.DEFINE_string(
    "model_name_or_path",
    None,
    "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)

flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
)

flags.DEFINE_string("config_name", "", "Pretrained config name or path if not the same as model_name")

flags.DEFINE_string("tokenizer_name", "", "Pretrained tokenizer name or path if not the same as model_name")

flags.DEFINE_string("cache_dir", "", "Where to store the pre-trained models downloaded from s3")

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sentence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter "
    "will be padded.",
)

flags.DEFINE_string(
    "tpu",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "URL.",
)

flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")

flags.DEFINE_boolean("do_train", False, "Whether to run training.")

flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")

flags.DEFINE_boolean(
    "evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
)

flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")

flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")

flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")

flags.DEFINE_integer(
    "gradient_accumulation_steps", 1, "Number of update steps to accumulate before performing a backward/update pass."
)

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("weight_decay", 0.0, "Weight decay if we apply some.")

flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for the Adam optimizer.")

flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")

flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")

flags.DEFINE_integer(
    "max_steps", -1, "If > 0: set total number of training steps to perform. Overrides num_train_epochs."
)

flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")

flags.DEFINE_integer("logging_steps", 50, "Log every X update steps.")

flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X update steps.")

flags.DEFINE_boolean(
    "eval_all_checkpoints",
    False,
    "Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number",
)

flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA even when it is available")

flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")

flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")

flags.DEFINE_integer("seed", 42, "Random seed for initialization")

flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")

flags.DEFINE_string(
    "gpus",
    "0",
    "Comma separated list of GPU devices. If only one, switch to single "
    "GPU strategy; if None, take all available GPUs.",
)
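

# Example invocation (paths and the model name are illustrative):
#
#   python run_tf_ner.py \
#     --data_dir ./data \
#     --model_type bert \
#     --model_name_or_path bert-base-cased \
#     --output_dir ./model_output \
#     --do_train \
#     --do_eval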


def train(
    args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
):
    if args["max_steps"] > 0:
        num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
        args["num_train_epochs"] = 1
    else:
        num_train_steps = (
            math.ceil(num_train_examples / train_batch_size)
            // args["gradient_accumulation_steps"]
            * args["num_train_epochs"]
        )

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
        optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])

        if args["fp16"]:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")

        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args["num_train_epochs"])
    logging.info("  Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * args["gradient_accumulation_steps"],
    )
    logging.info("  Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()
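
    # Averages the accumulated gradients over devices and accumulation steps,
    # applies a single optimizer update, then clears the accumulator.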

    @tf.function
    def apply_gradients():
        grads_and_vars = []

        for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
        gradient_accumulator.reset()
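
    # Each replica computes the loss over non-padded positions only (masked via
    # input_mask) and adds its gradients to the shared accumulator.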

    @tf.function
    def train_step(train_features, train_labels):
        def step_fn(train_features, train_labels):
            inputs = {"attention_mask": train_features["input_mask"], "training": True}

            if args["model_type"] != "distilbert":
                inputs["token_type_ids"] = (
                    train_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
                )

            with tf.GradientTape() as tape:
                logits = model(train_features["input_ids"], **inputs)[0]
                logits = tf.reshape(logits, (-1, len(labels) + 1))
                active_loss = tf.reshape(train_features["input_mask"], (-1,))
                active_logits = tf.boolean_mask(logits, active_loss)
                train_labels = tf.reshape(train_labels, (-1,))
                active_labels = tf.boolean_mask(train_labels, active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)

        return mean_loss

    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args["num_train_epochs"]))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(
            train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
        )
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args["gradient_accumulation_steps"] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
                        # Log metrics
                        if (
                            args["n_device"] == 1 and args["evaluate_during_training"]
                        ):  # Only evaluate when single GPU otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(
                                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
                            )
                            report = metrics.classification_report(y_true, y_pred, digits=4)

                            logging.info("Eval at step " + str(global_step) + "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss, global_step)
                                tf.summary.scalar("precision", precision, global_step)
                                tf.summary.scalar("recall", recall, global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar(
                                "loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
                            )

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss", loss_metric.result(), step=step)

                    if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s", output_dir)

                train_iterator.child.comment = f"loss : {loss_metric.result()}"
                step += 1

        train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() - current_time))


def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
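    # Runs a full pass over the dev or test set and returns the gold label
    # sequences, the predicted label sequences and the mean loss; padded
    # positions are filtered out via the attention mask and pad_token_label_id.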
    eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
    eval_dataset, size = load_and_cache_examples(
        args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
    )
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    preds = None
    num_eval_steps = math.ceil(size / eval_batch_size)
    master = master_bar(range(1))
    eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
    loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    loss = 0.0

    logging.info("***** Running evaluation *****")
    logging.info("  Num examples = %d", size)
    logging.info("  Batch size = %d", eval_batch_size)

    for eval_features, eval_labels in eval_iterator:
        inputs = {"attention_mask": eval_features["input_mask"], "training": False}

        if args["model_type"] != "distilbert":
            inputs["token_type_ids"] = (
                eval_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
            )

        with strategy.scope():
            logits = model(eval_features["input_ids"], **inputs)[0]
            tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
            active_loss = tf.reshape(eval_features["input_mask"], (-1,))
            active_logits = tf.boolean_mask(tmp_logits, active_loss)
            tmp_eval_labels = tf.reshape(eval_labels, (-1,))
            active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
            cross_entropy = loss_fct(active_labels, active_logits)
            loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)

        if preds is None:
            preds = logits.numpy()
            label_ids = eval_labels.numpy()
        else:
            preds = np.append(preds, logits.numpy(), axis=0)
            label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)

    preds = np.argmax(preds, axis=2)
    y_pred = [[] for _ in range(label_ids.shape[0])]
    y_true = [[] for _ in range(label_ids.shape[0])]
    loss = loss / num_eval_steps

    for i in range(label_ids.shape[0]):
        for j in range(label_ids.shape[1]):
            if label_ids[i, j] != pad_token_label_id:
                y_pred[i].append(labels[preds[i, j] - 1])
                y_true[i].append(labels[label_ids[i, j] - 1])

    return y_true, y_pred, loss.numpy()


def load_cache(cached_file, max_seq_length):
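    # The schema below must mirror what save_cache() writes: four fixed-length
    # int64 sequences per example.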
    name_to_features = {
        "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    }

    def _decode_record(record):
        example = tf.io.parse_single_example(record, name_to_features)
        features = {}
        features["input_ids"] = example["input_ids"]
        features["input_mask"] = example["input_mask"]
        features["segment_ids"] = example["segment_ids"]

        return features, example["label_ids"]

    d = tf.data.TFRecordDataset(cached_file)
    d = d.map(_decode_record, num_parallel_calls=4)
    count = d.reduce(0, lambda x, _: x + 1)

    return d, count.numpy()


def save_cache(features, cached_features_file):
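    # Serializes each feature object produced by convert_examples_to_features
    # as a tf.train.Example of int64 lists, so tokenization runs only once per
    # (mode, model, max_seq_length) combination.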
    writer = tf.io.TFRecordWriter(cached_features_file)

    for (ex_index, feature) in enumerate(features):
        if ex_index % 5000 == 0:
            logging.info("Writing example %d of %d" % (ex_index, len(features)))

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        record_feature = collections.OrderedDict()
        record_feature["input_ids"] = create_int_feature(feature.input_ids)
        record_feature["input_mask"] = create_int_feature(feature.input_mask)
        record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
        record_feature["label_ids"] = create_int_feature(feature.label_ids)

        tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))

        writer.write(tf_example.SerializeToString())

    writer.close()


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
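    # Batches must have a static shape on TPU; training likewise drops the
    # last partial batch so every step sees a full batch.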
    drop_remainder = True if args["tpu"] or mode == "train" else False

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args["data_dir"],
        "cached_{}_{}_{}.tf_record".format(
            mode, list(filter(None, args["model_name_or_path"].split("/"))).pop(), str(args["max_seq_length"])
        ),
    )
    if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
        logging.info("Loading features from cached file %s", cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])
    else:
        logging.info("Creating features from dataset file at %s", args["data_dir"])
        examples = read_examples_from_file(args["data_dir"], mode)
        features = convert_examples_to_features(
            examples,
            labels,
            args["max_seq_length"],
            tokenizer,
            cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=bool(args["model_type"] in ["roberta"]),
            # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=bool(args["model_type"] in ["xlnet"]),
            # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if args["model_type"] in ["xlnet"] else 0,
            pad_token_label_id=pad_token_label_id,
        )
        logging.info("Saving features into cached file %s", cached_features_file)
        save_cache(features, cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])

    if mode == "train":
        dataset = dataset.repeat()
        dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])

    dataset = dataset.batch(batch_size, drop_remainder)
    dataset = dataset.prefetch(buffer_size=batch_size)

    return dataset, size


def main(_):
    logging.set_verbosity(logging.INFO)
    args = flags.FLAGS.flag_values_dict()

    if (
        os.path.exists(args["output_dir"])
        and os.listdir(args["output_dir"])
        and args["do_train"]
        and not args["overwrite_output_dir"]
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args["output_dir"]
            )
        )

    if args["fp16"]:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
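
    # Pick the tf.distribute strategy: TPU if --tpu is set, MirroredStrategy
    # across several GPUs, CPU-only when --no_cuda is set, otherwise one GPU.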

    if args["tpu"]:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        args["n_device"] = args["num_tpu_cores"]
    elif len(args["gpus"].split(",")) > 1:
        args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
        strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
    elif args["no_cuda"]:
        args["n_device"] = 1
        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    else:
        args["n_device"] = len(args["gpus"].split(","))
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])

    logging.warning(
        "n_device: %s, distributed training: %s, 16-bits training: %s",
        args["n_device"],
        bool(args["n_device"] > 1),
        args["fp16"],
    )

    labels = get_labels(args["labels"])
    num_labels = len(labels) + 1
    pad_token_label_id = 0
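    # Label ids are shifted by one so that id 0 is reserved for padding
    # (pad_token_label_id); evaluate() undoes the shift with labels[id - 1].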
    config = AutoConfig.from_pretrained(
        args["config_name"] if args["config_name"] else args["model_name_or_path"],
        num_labels=num_labels,
        cache_dir=args["cache_dir"] if args["cache_dir"] else None,
    )

    logging.info("Training/evaluation parameters %s", args)

    # Training
    if args["do_train"]:
        tokenizer = AutoTokenizer.from_pretrained(
            args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
            do_lower_case=args["do_lower_case"],
            cache_dir=args["cache_dir"] if args["cache_dir"] else None,
        )

        with strategy.scope():
            model = TFAutoModelForTokenClassification.from_pretrained(
                args["model_name_or_path"],
                from_pt=bool(".bin" in args["model_name_or_path"]),
                config=config,
                cache_dir=args["cache_dir"] if args["cache_dir"] else None,
            )
            model.layers[-1].activation = tf.keras.activations.softmax
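            # The token classification head returns raw logits; swap in a
            # softmax activation because the loss in train()/evaluate() is
            # SparseCategoricalCrossentropy with the default from_logits=False.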

        train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
        train_dataset, num_train_examples = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
        )
        train_dataset = strategy.experimental_distribute_dataset(train_dataset)
        train(
            args,
            strategy,
            train_dataset,
            tokenizer,
            model,
            num_train_examples,
            labels,
            train_batch_size,
            pad_token_label_id,
        )

        if not os.path.exists(args["output_dir"]):
            os.makedirs(args["output_dir"])

        logging.info("Saving model to %s", args["output_dir"])

        model.save_pretrained(args["output_dir"])
        tokenizer.save_pretrained(args["output_dir"])

    # Evaluation
    if args["do_eval"]:
        tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        checkpoints = []
        results = []

        if args["eval_all_checkpoints"]:
            checkpoints = list(
                os.path.dirname(c)
                for c in sorted(
                    glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
                    key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
                )
            )

        logging.info("Evaluate the following checkpoints: %s", checkpoints)

        if len(checkpoints) == 0:
            checkpoints.append(args["output_dir"])

        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"

            with strategy.scope():
                model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)

            y_true, y_pred, eval_loss = evaluate(
                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
            )
            report = metrics.classification_report(y_true, y_pred, digits=4)

            if global_step:
                results.append({global_step + "_report": report, global_step + "_loss": eval_loss})

        output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")

        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            for res in results:
                for key, val in res.items():
                    if "loss" in key:
                        logging.info(key + " = " + str(val))
                        writer.write(key + " = " + str(val))
                        writer.write("\n")
                    else:
                        logging.info(key)
                        logging.info("\n" + val)
                        writer.write(key + "\n")
                        writer.write(val)
                        writer.write("\n")

    if args["do_predict"]:
        tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
        eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
        predict_dataset, _ = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
        )
        y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
        output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
        output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
        report = metrics.classification_report(y_true, y_pred, digits=4)

        with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
            logging.info("\n" + report)

            writer.write(report)
            writer.write("\n\nloss = " + str(pred_loss))

        with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
            with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
                example_id = 0
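
                # Walk the original CoNLL test file in lockstep with y_pred:
                # blank lines and -DOCSTART- markers delimit sentences; tokens
                # beyond max_seq_length have no prediction and are only logged.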

                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)

                        if not y_pred[example_id]:
                            example_id += 1
                    elif y_pred[example_id]:
                        output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("output_dir")
    flags.mark_flag_as_required("model_name_or_path")
    flags.mark_flag_as_required("model_type")
    app.run(main)