# coding=utf-8
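#
# Fine-tunes a Transformer model (BERT, RoBERTa, or DistilBERT) for named-entity
# recognition with TensorFlow 2 on CoNLL-style token-classification data.
#
# Example invocation (the paths and model choice below are illustrative only):
#
#   python run_tf_ner.py \
#     --data_dir ./data \
#     --model_type bert \
#     --model_name_or_path bert-base-cased \
#     --output_dir ./output \
#     --do_train \
#     --do_eval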
import datetime
import os
import math
import glob
import re
import tensorflow as tf
import collections
import numpy as np
from seqeval import metrics
from absl import logging
from transformers import TF2_WEIGHTS_NAME, BertConfig, BertTokenizer, TFBertForTokenClassification
from transformers import RobertaConfig, RobertaTokenizer, TFRobertaForTokenClassification
from transformers import DistilBertConfig, DistilBertTokenizer, TFDistilBertForTokenClassification
from transformers import create_optimizer, GradientAccumulator
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
from fastprogress import master_bar, progress_bar
from absl import flags
from absl import app


ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
)

MODEL_CLASSES = {
    "bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
    "roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
}


flags.DEFINE_string(
    "data_dir", None, "The input data dir. Should contain the .conll files (or other data files) for the task."
)

flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))

flags.DEFINE_string(
    "model_name_or_path",
    None,
    "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
)

flags.DEFINE_string("output_dir", None, "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "labels", "", "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
)

flags.DEFINE_string("config_name", "", "Pretrained config name or path if not the same as model_name")

flags.DEFINE_string("tokenizer_name", "", "Pretrained tokenizer name or path if not the same as model_name")

flags.DEFINE_string("cache_dir", "", "Where do you want to store the pre-trained models downloaded from s3")

flags.DEFINE_integer(
    "max_seq_length",
    128,
    "The maximum total input sentence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter "
    "will be padded.",
)

flags.DEFINE_string(
    "tpu",
    None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.",
)

flags.DEFINE_integer("num_tpu_cores", 8, "Total number of TPU cores to use.")

flags.DEFINE_boolean("do_train", False, "Whether to run training.")

flags.DEFINE_boolean("do_eval", False, "Whether to run eval on the dev set.")

flags.DEFINE_boolean("do_predict", False, "Whether to run predictions on the test set.")

flags.DEFINE_boolean(
    "evaluate_during_training", False, "Whether to run evaluation during training at each logging step."
)

flags.DEFINE_boolean("do_lower_case", False, "Set this flag if you are using an uncased model.")

flags.DEFINE_integer("per_device_train_batch_size", 8, "Batch size per GPU/CPU/TPU for training.")

flags.DEFINE_integer("per_device_eval_batch_size", 8, "Batch size per GPU/CPU/TPU for evaluation.")

flags.DEFINE_integer(
    "gradient_accumulation_steps", 1, "Number of update steps to accumulate before performing a backward/update pass."
)

flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")

flags.DEFINE_float("weight_decay", 0.0, "Weight decay if we apply some.")

flags.DEFINE_float("adam_epsilon", 1e-8, "Epsilon for Adam optimizer.")

flags.DEFINE_float("max_grad_norm", 1.0, "Max gradient norm.")

flags.DEFINE_integer("num_train_epochs", 3, "Total number of training epochs to perform.")

flags.DEFINE_integer(
    "max_steps", -1, "If > 0: set total number of training steps to perform. Overrides num_train_epochs."
)

flags.DEFINE_integer("warmup_steps", 0, "Linear warmup over warmup_steps.")

flags.DEFINE_integer("logging_steps", 50, "Log every X update steps.")

flags.DEFINE_integer("save_steps", 50, "Save checkpoint every X update steps.")

flags.DEFINE_boolean(
    "eval_all_checkpoints",
    False,
    "Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number",
)

flags.DEFINE_boolean("no_cuda", False, "Avoid using CUDA when available")

flags.DEFINE_boolean("overwrite_output_dir", False, "Overwrite the content of the output directory")

flags.DEFINE_boolean("overwrite_cache", False, "Overwrite the cached training and evaluation sets")

flags.DEFINE_integer("seed", 42, "Random seed for initialization")

flags.DEFINE_boolean("fp16", False, "Whether to use 16-bit (mixed) precision instead of 32-bit")

flags.DEFINE_string(
    "gpus",
    "0",
    "Comma separated list of GPU devices. If only one is given, the single-GPU "
    "strategy is used; if None, all available GPUs are used.",
)


def train(
    args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id
):
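    # Total number of training steps: taken from --max_steps when given (forcing a
    # single epoch), otherwise derived from dataset size, batch size, gradient
    # accumulation, and epoch count.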
    if args["max_steps"] > 0:
        num_train_steps = args["max_steps"] * args["gradient_accumulation_steps"]
        args["num_train_epochs"] = 1
    else:
        num_train_steps = (
            math.ceil(num_train_examples / train_batch_size)
            // args["gradient_accumulation_steps"]
            * args["num_train_epochs"]
        )

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
        optimizer = create_optimizer(args["learning_rate"], num_train_steps, args["warmup_steps"])

        if args["fp16"]:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, "dynamic")

        loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32)
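        # The GradientAccumulator collects gradients across micro-batches so that one
        # optimizer update can be applied every gradient_accumulation_steps batches.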
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args["num_train_epochs"])
    logging.info("  Instantaneous batch size per device = %d", args["per_device_train_batch_size"])
    logging.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size * args["gradient_accumulation_steps"],
    )
    logging.info("  Gradient Accumulation steps = %d", args["gradient_accumulation_steps"])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()

    @tf.function
    def apply_gradients():
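        # Average the accumulated gradients over all devices and accumulation steps,
        # apply a single clipped optimizer update, then reset the accumulator.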
        grads_and_vars = []

        for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (args["n_device"] * args["gradient_accumulation_steps"])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args["max_grad_norm"])
        gradient_accumulator.reset()

    @tf.function
    def train_step(train_features, train_labels):
        def step_fn(train_features, train_labels):
            inputs = {"attention_mask": train_features["input_mask"], "training": True}

            if args["model_type"] != "distilbert":
                inputs["token_type_ids"] = (
                    train_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
                )

            with tf.GradientTape() as tape:
                logits = model(train_features["input_ids"], **inputs)[0]
                logits = tf.reshape(logits, (-1, len(labels) + 1))
                active_loss = tf.reshape(train_features["input_mask"], (-1,))
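                # input_mask is 1 for real tokens and 0 for padding, so masking with it
                # restricts the loss to actual tokens.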
                active_logits = tf.boolean_mask(logits, active_loss)
                train_labels = tf.reshape(train_labels, (-1,))
                active_labels = tf.boolean_mask(train_labels, active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)

        return mean_loss

    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args["num_train_epochs"]))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(
            train_dataset, total=num_train_steps, parent=train_iterator, display=args["n_device"] > 1
        )
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args["gradient_accumulation_steps"] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
                        # Log metrics
                        if (
                            args["n_device"] == 1 and args["evaluate_during_training"]
                        ):  # Only evaluate when single GPU otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(
                                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
                            )
                            report = metrics.classification_report(y_true, y_pred, digits=4)

                            logging.info("Eval at step " + str(global_step) + "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss, global_step)
                                tf.summary.scalar("precision", precision, global_step)
                                tf.summary.scalar("recall", recall, global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar(
                                "loss", (loss_metric.result() - logging_loss) / args["logging_steps"], global_step
                            )

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss", loss_metric.result(), step=step)

                    if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(args["output_dir"], "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s", output_dir)

                train_iterator.child.comment = f"loss : {loss_metric.result()}"
                step += 1

        train_iterator.write(f"loss epoch {epoch + 1}: {loss_metric.result()}")

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() - current_time))


def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
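    # Runs a full pass over the dev/test split and returns gold tag sequences,
    # predicted tag sequences, and the mean evaluation loss.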
    eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
    eval_dataset, size = load_and_cache_examples(
        args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode
    )
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    preds = None
    num_eval_steps = math.ceil(size / eval_batch_size)
    master = master_bar(range(1))
    eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args["n_device"] > 1)
    loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    loss = 0.0

    logging.info("***** Running evaluation *****")
    logging.info("  Num examples = %d", size)
    logging.info("  Batch size = %d", eval_batch_size)

    for eval_features, eval_labels in eval_iterator:
        inputs = {"attention_mask": eval_features["input_mask"], "training": False}

        if args["model_type"] != "distilbert":
            inputs["token_type_ids"] = (
                eval_features["segment_ids"] if args["model_type"] in ["bert", "xlnet"] else None
            )

        with strategy.scope():
            logits = model(eval_features["input_ids"], **inputs)[0]
            tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
            active_loss = tf.reshape(eval_features["input_mask"], (-1,))
            active_logits = tf.boolean_mask(tmp_logits, active_loss)
            tmp_eval_labels = tf.reshape(eval_labels, (-1,))
            active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
            cross_entropy = loss_fct(active_labels, active_logits)
            loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)

        if preds is None:
            preds = logits.numpy()
            label_ids = eval_labels.numpy()
        else:
            preds = np.append(preds, logits.numpy(), axis=0)
            label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)

    preds = np.argmax(preds, axis=2)
    y_pred = [[] for _ in range(label_ids.shape[0])]
    y_true = [[] for _ in range(label_ids.shape[0])]
    loss = loss / num_eval_steps

    for i in range(label_ids.shape[0]):
        for j in range(label_ids.shape[1]):
            if label_ids[i, j] != pad_token_label_id:
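                # Predicted and gold ids are offset by 1 relative to `labels`
                # because id 0 is reserved for the padding label.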
                y_pred[i].append(labels[preds[i, j] - 1])
                y_true[i].append(labels[label_ids[i, j] - 1])

    return y_true, y_pred, loss.numpy()


def load_cache(cached_file, max_seq_length):
    name_to_features = {
        "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    }

    def _decode_record(record):
        example = tf.io.parse_single_example(record, name_to_features)
        features = {}
        features["input_ids"] = example["input_ids"]
        features["input_mask"] = example["input_mask"]
        features["segment_ids"] = example["segment_ids"]

        return features, example["label_ids"]

    d = tf.data.TFRecordDataset(cached_file)
    d = d.map(_decode_record, num_parallel_calls=4)
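    # TFRecord files do not store their length, so count the records with a full
    # pass; the size is needed later to compute steps per epoch.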
    count = d.reduce(0, lambda x, _: x + 1)

    return d, count.numpy()


def save_cache(features, cached_features_file):
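    # Serializes each feature set as a tf.train.Example record so later runs can
    # stream the cached dataset from disk instead of re-tokenizing.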
    writer = tf.io.TFRecordWriter(cached_features_file)

    for (ex_index, feature) in enumerate(features):
        if ex_index % 5000 == 0:
            logging.info("Writing example %d of %d" % (ex_index, len(features)))

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        record_feature = collections.OrderedDict()
        record_feature["input_ids"] = create_int_feature(feature.input_ids)
        record_feature["input_mask"] = create_int_feature(feature.input_mask)
        record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
        record_feature["label_ids"] = create_int_feature(feature.label_ids)

        tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))

        writer.write(tf_example.SerializeToString())

    writer.close()


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
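    # TPUs require static batch shapes, so the last partial batch is dropped on TPU
    # and during training.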
    drop_remainder = True if args["tpu"] or mode == "train" else False

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(
        args["data_dir"],
        "cached_{}_{}_{}.tf_record".format(
            mode, list(filter(None, args["model_name_or_path"].split("/"))).pop(), str(args["max_seq_length"])
        ),
    )
    if os.path.exists(cached_features_file) and not args["overwrite_cache"]:
        logging.info("Loading features from cached file %s", cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])
    else:
        logging.info("Creating features from dataset file at %s", args["data_dir"])
        examples = read_examples_from_file(args["data_dir"], mode)
        features = convert_examples_to_features(
            examples,
            labels,
            args["max_seq_length"],
            tokenizer,
            cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
            sep_token=tokenizer.sep_token,
            sep_token_extra=bool(args["model_type"] in ["roberta"]),
            # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
            pad_on_left=bool(args["model_type"] in ["xlnet"]),
            # pad on the left for xlnet
            pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
            pad_token_segment_id=4 if args["model_type"] in ["xlnet"] else 0,
            pad_token_label_id=pad_token_label_id,
        )
        logging.info("Saving features into cached file %s", cached_features_file)
        save_cache(features, cached_features_file)
        dataset, size = load_cache(cached_features_file, args["max_seq_length"])

    if mode == "train":
        dataset = dataset.repeat()
        dataset = dataset.shuffle(buffer_size=8192, seed=args["seed"])

    dataset = dataset.batch(batch_size, drop_remainder)
    dataset = dataset.prefetch(buffer_size=batch_size)

    return dataset, size


def main(_):
    logging.set_verbosity(logging.INFO)
    args = flags.FLAGS.flag_values_dict()

    if (
        os.path.exists(args["output_dir"])
        and os.listdir(args["output_dir"])
        and args["do_train"]
        and not args["overwrite_output_dir"]
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args["output_dir"]
            )
        )

    if args["fp16"]:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

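    # Pick a distribution strategy (TPU, mirrored multi-GPU, CPU-only, or a single
    # GPU) and record the resulting device count in n_device.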
    if args["tpu"]:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args["tpu"])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        args["n_device"] = args["num_tpu_cores"]
    elif len(args["gpus"].split(",")) > 1:
        args["n_device"] = len([f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
        strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args["gpus"].split(",")])
    elif args["no_cuda"]:
        args["n_device"] = 1
        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    else:
        args["n_device"] = len(args["gpus"].split(","))
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args["gpus"].split(",")[0])

    logging.warning(
        "n_device: %s, distributed training: %s, 16-bits training: %s",
        args["n_device"],
        bool(args["n_device"] > 1),
        args["fp16"],
    )

    labels = get_labels(args["labels"])
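    # Label id 0 is reserved for padding, so the classifier needs len(labels) + 1 classes.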
    num_labels = len(labels) + 1
    pad_token_label_id = 0
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    config = config_class.from_pretrained(
        args["config_name"] if args["config_name"] else args["model_name_or_path"],
        num_labels=num_labels,
        cache_dir=args["cache_dir"] if args["cache_dir"] else None,
    )

    logging.info("Training/evaluation parameters %s", args)

    # Training
    if args["do_train"]:
        tokenizer = tokenizer_class.from_pretrained(
            args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
            do_lower_case=args["do_lower_case"],
            cache_dir=args["cache_dir"] if args["cache_dir"] else None,
        )

        with strategy.scope():
            model = model_class.from_pretrained(
                args["model_name_or_path"],
                from_pt=bool(".bin" in args["model_name_or_path"]),
                config=config,
                cache_dir=args["cache_dir"] if args["cache_dir"] else None,
            )
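            # The token-classification head outputs raw logits; swapping in softmax makes
            # the outputs probabilities, matching SparseCategoricalCrossentropy's default
            # from_logits=False used in train() and evaluate().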
            model.layers[-1].activation = tf.keras.activations.softmax

        train_batch_size = args["per_device_train_batch_size"] * args["n_device"]
        train_dataset, num_train_examples = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train"
        )
        train_dataset = strategy.experimental_distribute_dataset(train_dataset)
        train(
            args,
            strategy,
            train_dataset,
            tokenizer,
            model,
            num_train_examples,
            labels,
            train_batch_size,
            pad_token_label_id,
        )

        if not os.path.exists(args["output_dir"]):
            os.makedirs(args["output_dir"])

        logging.info("Saving model to %s", args["output_dir"])

        model.save_pretrained(args["output_dir"])
        tokenizer.save_pretrained(args["output_dir"])

    # Evaluation
    if args["do_eval"]:
        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        checkpoints = []
        results = []

        if args["eval_all_checkpoints"]:
            checkpoints = list(
                os.path.dirname(c)
                for c in sorted(
                    glob.glob(args["output_dir"] + "/**/" + TF2_WEIGHTS_NAME, recursive=True),
                    key=lambda f: int("".join(filter(str.isdigit, f)) or -1),
                )
            )

        logging.info("Evaluate the following checkpoints: %s", checkpoints)

        if len(checkpoints) == 0:
            checkpoints.append(args["output_dir"])

        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"

            with strategy.scope():
                model = model_class.from_pretrained(checkpoint)

            y_true, y_pred, eval_loss = evaluate(
                args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
            )
            report = metrics.classification_report(y_true, y_pred, digits=4)

            if global_step:
                results.append({global_step + "_report": report, global_step + "_loss": eval_loss})

        output_eval_file = os.path.join(args["output_dir"], "eval_results.txt")

        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            for res in results:
                for key, val in res.items():
                    if "loss" in key:
                        logging.info(key + " = " + str(val))
                        writer.write(key + " = " + str(val))
                        writer.write("\n")
                    else:
                        logging.info(key)
                        logging.info("\n" + val)
                        writer.write(key + "\n")
                        writer.write(val)
                        writer.write("\n")

    if args["do_predict"]:
        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
        model = model_class.from_pretrained(args["output_dir"])
        eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
        predict_dataset, _ = load_and_cache_examples(
            args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
        )
        y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
        output_test_results_file = os.path.join(args["output_dir"], "test_results.txt")
        output_test_predictions_file = os.path.join(args["output_dir"], "test_predictions.txt")
        report = metrics.classification_report(y_true, y_pred, digits=4)

        with tf.io.gfile.GFile(output_test_results_file, "w") as writer:

            logging.info("\n" + report)

            writer.write(report)
            writer.write("\n\nloss = " + str(pred_loss))

        with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
            with tf.io.gfile.GFile(os.path.join(args["data_dir"], "test.txt"), "r") as f:
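                # Walk the original test file line by line, emitting one predicted tag per
                # token and copying document/sentence boundaries through unchanged.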
                example_id = 0

                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)

                        if not y_pred[example_id]:
                            example_id += 1
                    elif y_pred[example_id]:
                        output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("output_dir")
    flags.mark_flag_as_required("model_name_or_path")
    flags.mark_flag_as_required("model_type")
    app.run(main)