"ts/nni_manager/vscode:/vscode.git/clone" did not exist on "80bc9537b6952a879e3a7805996cb3c862a29ee8"
run_tf_ner.py 25.9 KB
Newer Older
1
# coding=utf-8
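"""Fine-tune a Transformers token-classification model (BERT, RoBERTa or
DistilBERT) on a NER dataset with TensorFlow 2.0.

Illustrative invocation (paths and the model name are placeholders):

    python run_tf_ner.py --data_dir ./data --model_type bert \
        --model_name_or_path bert-base-cased --output_dir ./out \
        --do_train --do_eval
"""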
import collections
import datetime
import glob
import math
import os
import re

import numpy as np
import tensorflow as tf
from absl import app, flags, logging
from fastprogress import master_bar, progress_bar
from seqeval import metrics
from transformers import (
    TF2_WEIGHTS_NAME,
    BertConfig,
    BertTokenizer,
    DistilBertConfig,
    DistilBertTokenizer,
    GradientAccumulator,
    RobertaConfig,
    RobertaTokenizer,
    TFBertForTokenClassification,
    TFDistilBertForTokenClassification,
    TFRobertaForTokenClassification,
    create_optimizer,
)
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file


ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
)

MODEL_CLASSES = {
    "bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
    "roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
}


flags.DEFINE_string(
    "data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
)

flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))

flags.DEFINE_string(
    "model_name_or_path",
    None,
    "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)

flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
)

flags.DEFINE_string("config_name", "", "Pretrained config name or path if not the same as model_name")

flags.DEFINE_string("tokenizer_name", "", "Pretrained tokenizer name or path if not the same as model_name")

flags.DEFINE_string("cache_dir", "", "Directory in which to store the pre-trained models downloaded from s3.")

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sentence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter "
    "will be padded.",
)

flags.DEFINE_string(
    "tpu",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.",
)

flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")

flags.DEFINE_boolean("do_train", False, "Whether to run training.")

flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")

flags.DEFINE_boolean(
    "evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
)

flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")

flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")

flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")

flags.DEFINE_integer(
    "gradient_accumulation_steps", 1, "Number of updates steps to accumulate before performing a backward/update pass."
)

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("weight_decay", 0.0, "Weight decay to apply, if any.")

flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for Adam optimizer.")

flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")

flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")

flags.DEFINE_integer(
    "max_steps", -1, "If > 0: set total number of training steps to perform. Override num_train_epochs."
)

flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")

flags.DEFINE_integer("logging_steps", 50, "Log every X update steps.")

flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X update steps.")

flags.DEFINE_boolean(
    "eval_all_checkpoints",
    False,
    "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
)

flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA when available")

flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")

flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")

flags.DEFINE_integer("seed", 42, "Random seed for initialization.")

flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")

flags.DEFINE_string(
    "gpus",
    "0",
    "Comma separated list of gpus devices. If only one, switch to single "
145
146
    "gpu strategy, if None takes all the gpus available.",
)


def train(
    args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
):
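    """Custom distributed training loop.

    Runs `train_step` on every replica, accumulates gradients over
    `gradient_accumulation_steps` batches before applying them, and
    periodically logs metrics, optionally evaluates, and saves checkpoints.
    """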
    if args["max_steps"] > 0:
        num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
        args["num_train_epochs"] = 1
    else:
        num_train_steps = (
            math.ceil(num_train_examples / train_batch_size)
            // args["gradient_accumulation_steps"]
            * args["num_train_epochs"]
        )

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
        optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])

        if args["fp16"]:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")

        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args["num_train_epochs"])
    logging.info("  Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * args["gradient_accumulation_steps"],
    )
    logging.info("  Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()

    @tf.function
    def apply_gradients():
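        # Average the accumulated gradients over devices and accumulation
        # steps, hand them (plus `max_grad_norm`) to the optimizer, then
        # reset the accumulator for the next accumulation window.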
        grads_and_vars = []

        for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
        gradient_accumulator.reset()

    @tf.function
    def train_step(train_features, train_labels):
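        # One distributed step: each replica computes the masked token-level
        # cross-entropy and feeds its gradients into the accumulator; the
        # per-example losses are averaged across replicas afterwards.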
        def step_fn(train_features, train_labels):
            inputs = {"attention_mask": train_features["input_mask"], "training": True}

            if args["model_type"] != "distilbert":
                inputs["token_type_ids"] = (
                    train_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
                )

            with tf.GradientTape() as tape:
                logits = model(train_features["input_ids"], **inputs)[0]
                logits = tf.reshape(logits, (-1, len(labels) + 1))
                active_loss = tf.reshape(train_features["input_mask"], (-1,))
                active_logits = tf.boolean_mask(logits, active_loss)
                train_labels = tf.reshape(train_labels, (-1,))
                active_labels = tf.boolean_mask(train_labels, active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)

        return mean_loss

    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args["num_train_epochs"]))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(
            train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
        )
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args["gradient_accumulation_steps"] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
                        # Log metrics
                        if (
                            args["n_device"] == 1 and args["evaluate_during_training"]
                        ):  # Only evaluate when single GPU otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(
                                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
                            )
                            report = metrics.classification_report(y_true, y_pred, digits=4)

                            logging.info("Eval at step " + str(global_step) + "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss, global_step)
                                tf.summary.scalar("precision", precision, global_step)
                                tf.summary.scalar("recall", recall, global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar(
                                "loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
                            )

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss", loss_metric.result(), step=step)

                    if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s", output_dir)

                train_iterator.child.comment = f"loss : {loss_metric.result()}"
                step += 1

        train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() - current_time))


def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
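    """Evaluate `model` on the `mode` split.

    Returns (y_true, y_pred, loss), where y_true/y_pred are per-sentence
    label sequences with padding positions stripped out.
    """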
    eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
    eval_dataset, size = load_and_cache_examples(
        args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
    )
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    preds = None
    num_eval_steps = math.ceil(size / eval_batch_size)
    master = master_bar(range(1))
    eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
    loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    loss = 0.0

    logging.info("***** Running evaluation *****")
    logging.info("  Num examples = %d", size)
    logging.info("  Batch size = %d", eval_batch_size)

    for eval_features, eval_labels in eval_iterator:
        inputs = {"attention_mask": eval_features["input_mask"], "training": False}

        if args["model_type"] != "distilbert":
            inputs["token_type_ids"] = (
                eval_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
            )

        with strategy.scope():
            logits = model(eval_features["input_ids"], **inputs)[0]
            tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
            active_loss = tf.reshape(eval_features["input_mask"], (-1,))
            active_logits = tf.boolean_mask(tmp_logits, active_loss)
            tmp_eval_labels = tf.reshape(eval_labels, (-1,))
            active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
            cross_entropy = loss_fct(active_labels, active_logits)
            loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)

        if preds is None:
            preds = logits.numpy()
            label_ids = eval_labels.numpy()
        else:
            preds = np.append(preds, logits.numpy(), axis=0)
            label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)

    preds = np.argmax(preds, axis=2)
    y_pred = [[] for _ in range(label_ids.shape[0])]
    y_true = [[] for _ in range(label_ids.shape[0])]
    loss = loss / num_eval_steps

    for i in range(label_ids.shape[0]):
        for j in range(label_ids.shape[1]):
            if label_ids[i, j] != pad_token_label_id:
                y_pred[i].append(labels[preds[i, j] - 1])
                y_true[i].append(labels[label_ids[i, j] - 1])

    return y_true, y_pred, loss.numpy()


def load_cache(cached_file, max_seq_length):
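    # Parse a cached TFRecord file back into (features, label_ids) examples
    # and return the dataset together with its example count.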
    name_to_features = {
        "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    }

    def _decode_record(record):
        example = tf.io.parse_single_example(record, name_to_features)
        features = {}
        features["input_ids"] = example["input_ids"]
        features["input_mask"] = example["input_mask"]
        features["segment_ids"] = example["segment_ids"]

        return features, example["label_ids"]

    d = tf.data.TFRecordDataset(cached_file)
    d = d.map(_decode_record, num_parallel_calls=4)
    count = d.reduce(0, lambda x, _: x + 1)

    return d, count.numpy()


def save_cache(features, cached_features_file):
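    # Serialize the tokenized features to a TFRecord file so subsequent runs
    # can skip the example-to-feature conversion.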
    writer = tf.io.TFRecordWriter(cached_features_file)

    for (ex_index, feature) in enumerate(features):
        if ex_index % 5000 == 0:
            logging.info("Writing example %d of %d" % (ex_index, len(features)))

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        record_feature = collections.OrderedDict()
        record_feature["input_ids"] = create_int_feature(feature.input_ids)
        record_feature["input_mask"] = create_int_feature(feature.input_mask)
        record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
        record_feature["label_ids"] = create_int_feature(feature.label_ids)

        tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))

        writer.write(tf_example.SerializeToString())

    writer.close()


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
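    # Build the tf.data pipeline for `mode`: read features from the TFRecord
    # cache when present, otherwise create them and cache them first.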
    drop_remainder = bool(args["tpu"]) or mode == "train"

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args["data_dir"],
        "cached_{}_{}_{}.tf_record".format(
            mode, list(filter(None, args["model_name_or_path"].split("/"))).pop(), str(args["max_seq_length"])
        ),
    )
    if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
        logging.info("Loading features from cached file %s", cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])
    else:
        logging.info("Creating features from dataset file at %s", args["data_dir"])
        examples = read_examples_from_file(args["data_dir"], mode)
        features = convert_examples_to_features(
            examples,
            labels,
            args["max_seq_length"],
            tokenizer,
            cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=bool(args["model_type"] in ["roberta"]),
            # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=bool(args["model_type"] in ["xlnet"]),
            # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if args["model_type"] in ["xlnet"] else 0,
            pad_token_label_id=pad_token_label_id,
        )
        logging.info("Saving features into cached file %s", cached_features_file)
        save_cache(features, cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])

    if mode == "train":
        dataset = dataset.repeat()
        dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])

    dataset = dataset.batch(batch_size, drop_remainder)
    dataset = dataset.prefetch(buffer_size=batch_size)

    return dataset, size


def main(_):
    logging.set_verbosity(logging.INFO)
    args = flags.FLAGS.flag_values_dict()

    if (
        os.path.exists(args["output_dir"])
        and os.listdir(args["output_dir"])
        and args["do_train"]
        and not args["overwrite_output_dir"]
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it.".format(
                args["output_dir"]
            )
        )

    if args["fp16"]:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

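    # Select the distribution strategy: TPU when --tpu is given, mirrored
    # multi-GPU when several GPUs are listed, otherwise a single CPU or GPU.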
    if args["tpu"]:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        args["n_device"] = args["num_tpu_cores"]
    elif len(args["gpus"].split(",")) > 1:
        args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
        strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
    elif args["no_cuda"]:
        args["n_device"] = 1
        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    else:
        args["n_device"] = len(args["gpus"].split(","))
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])

    logging.warning(
        "n_device: %s, distributed training: %s, 16-bits training: %s",
        args["n_device"],
        bool(args["n_device"] > 1),
        args["fp16"],
    )

    labels = get_labels(args["labels"])
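    # Label id 0 is reserved for padding, so real label ids are shifted up by
    # one (hence num_labels = len(labels) + 1 and the `- 1` in evaluate()).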
    num_labels = len(labels) + 1
    pad_token_label_id = 0
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    config = config_class.from_pretrained(
        args["config_name"] if args["config_name"] else args["model_name_or_path"],
        num_labels=num_labels,
        cache_dir=args["cache_dir"] if args["cache_dir"] else None,
    )

    logging.info("Training/evaluation parameters %s", args)

    # Training
    if args["do_train"]:
        tokenizer = tokenizer_class.from_pretrained(
            args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
            do_lower_case=args["do_lower_case"],
            cache_dir=args["cache_dir"] if args["cache_dir"] else None,
        )

        with strategy.scope():
            model = model_class.from_pretrained(
                args["model_name_or_path"],
                from_pt=bool(".bin" in args["model_name_or_path"]),
                config=config,
                cache_dir=args["cache_dir"] if args["cache_dir"] else None,
            )
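            # The pretrained head emits raw logits; switch the last layer to
            # softmax so the (non-from_logits) cross-entropy sees probabilities.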
            model.layers[-1].activation = tf.keras.activations.softmax

        train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
        train_dataset, num_train_examples = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
        )
        train_dataset = strategy.experimental_distribute_dataset(train_dataset)
        train(
            args,
            strategy,
            train_dataset,
            tokenizer,
            model,
            num_train_examples,
            labels,
            train_batch_size,
            pad_token_label_id,
        )

        if not os.path.exists(args["output_dir"]):
            os.makedirs(args["output_dir"])

        logging.info("Saving model to %s", args["output_dir"])

        model.save_pretrained(args["output_dir"])
        tokenizer.save_pretrained(args["output_dir"])

    # Evaluation
    if args["do_eval"]:
        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        checkpoints = []
        results = []

        if args["eval_all_checkpoints"]:
            checkpoints = list(
                os.path.dirname(c)
                for c in sorted(
                    glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
                    key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
                )
            )

        logging.info("Evaluate the following checkpoints: %s", checkpoints)

        if len(checkpoints) == 0:
            checkpoints.append(args["output_dir"])

        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"

            with strategy.scope():
                model = model_class.from_pretrained(checkpoint)

            y_true, y_pred, eval_loss = evaluate(
                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
            )
            report = metrics.classification_report(y_true, y_pred, digits=4)

            if global_step:
                results.append({global_step + "_report": report, global_step + "_loss": eval_loss})

        output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")

        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            for res in results:
                for key, val in res.items():
                    if "loss" in key:
                        logging.info(key + " = " + str(val))
                        writer.write(key + " = " + str(val))
                        writer.write("\n")
                    else:
                        logging.info(key)
                        logging.info("\n" + val)
                        writer.write(key + "\n")
                        writer.write(val)
                        writer.write("\n")

    if args["do_predict"]:
        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        model = model_class.from_pretrained(args["output_dir"])
        eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
        predict_dataset, _ = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
        )
        y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
        output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
        output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
        report = metrics.classification_report(y_true, y_pred, digits=4)

        with tf.io.gfile.GFile(output_test_results_file, "w") as writer:

            logging.info("\n" + report)

            writer.write(report)
            writer.write("\n\nloss = " + str(pred_loss))

        with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
            with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
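                # Walk the raw test file in step with the predictions, copying
                # document boundaries and blank lines through unchanged.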
                example_id = 0

                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)

                        if not y_pred[example_id]:
                            example_id += 1
                    elif y_pred[example_id]:
                        output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("output_dir")
    flags.mark_flag_as_required("model_name_or_path")
    flags.mark_flag_as_required("model_type")
    app.run(main)