{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YN2ACivEPxgD"
      },
      "source": [
        "## How-to Guide: Using a PIP package for fine-tuning a BERT model\n",
        "\n",
        "Authors: [Chen Chen](https://github.com/chenGitHuber), [Claire Yao](https://github.com/claireyao-fen)\n",
        "\n",
        "In this example, we will work through fine-tuning a BERT model using the tensorflow-models PIP package."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "T7BBEc1-RNCQ"
      },
      "source": [
        "## License\n",
        "\n",
        "Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Pf6xzoKjywY_"
      },
      "source": [
        "## Learning objectives\n",
        "\n",
        "In this Colab notebook, you will learn how to fine-tune a BERT model using the TensorFlow Model Garden PIP package."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YHkmV89jRWkS"
      },
      "source": [
        "## Enable GPU acceleration\n",
        "Please enable GPU acceleration for better performance.\n",
        "*   Navigate to the Edit menu.\n",
        "*   Open Notebook settings.\n",
        "*   Select GPU from the \"Hardware Accelerator\" drop-down list and save."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "s2d9S2CSSO1z"
      },
      "source": [
        "## Install and import"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fsACVQpVSifi"
      },
      "source": [
        "### Install the TensorFlow Model Garden pip package\n",
        "\n",
        "*  tf-models-nightly is the nightly Model Garden package, built automatically every day.\n",
        "*  pip will install all models and dependencies automatically."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NvNr2svBM-p3"
      },
      "outputs": [],
      "source": [
        "!pip install tf-models-nightly"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "U-7qPCjWUAyy"
      },
      "source": [
        "### Import TensorFlow and other libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lXsXev5MNr20"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "from official.modeling import tf_utils\n",
        "from official.nlp import optimization\n",
        "from official.nlp.bert import configs as bert_configs\n",
        "from official.nlp.bert import tokenization\n",
        "from official.nlp.data import classifier_data_lib\n",
        "from official.nlp.modeling import losses\n",
        "from official.nlp.modeling import models\n",
        "from official.nlp.modeling import networks"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "C2drjD7OVCmh"
      },
      "source": [
        "## Preprocess the raw data and write TFRecord files"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qfjcKj5FYQOp"
      },
      "source": [
        "### Introduction to the dataset\n",
        "\n",
        "The Microsoft Research Paraphrase Corpus (Dolan \u0026 Brockett, 2005) is a corpus of sentence pairs automatically extracted from online news sources, with human annotations indicating whether the sentences in each pair are semantically equivalent.\n",
        "\n",
        "*   Number of labels: 2.\n",
        "*   Size of training dataset: 3668.\n",
        "*   Size of evaluation dataset: 408.\n",
        "*   Maximum sequence length of training and evaluation dataset: 128.\n",
        "*   For more details, see https://www.tensorflow.org/datasets/catalog/glue#gluemrpc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "28DvUhC1YUiB"
      },
      "source": [
        "### Get dataset from TensorFlow Datasets (TFDS)\n",
        "\n",
        "In this example, we use the GLUE MRPC dataset from TFDS: https://www.tensorflow.org/datasets/catalog/glue#gluemrpc."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4PhRLWh9jaXp"
      },
      "source": [
        "### Preprocess the data and write to TensorFlow record file\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "FhcMdzsrjWzG"
      },
      "outputs": [],
      "source": [
        "gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12\"\n",
        "\n",
        "# Set up the tokenizer used to generate the TensorFlow dataset.\n",
        "tokenizer = tokenization.FullTokenizer(\n",
        "    vocab_file=os.path.join(gs_folder_bert, \"vocab.txt\"), do_lower_case=True)\n",
        "\n",
        "# Set up the processor used to generate the TensorFlow dataset.\n",
        "processor = classifier_data_lib.TfdsProcessor(\n",
        "    tfds_params=\"dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2\",\n",
        "    process_text_fn=tokenization.convert_to_unicode)\n",
        "\n",
        "# Set up output paths for the training and evaluation TFRecord files.\n",
        "train_data_output_path = \"./mrpc_train.tf_record\"\n",
        "eval_data_output_path = \"./mrpc_eval.tf_record\"\n",
        "\n",
        "# Generate and save the training and evaluation data as TFRecord files.\n",
        "input_meta_data = classifier_data_lib.generate_tf_record_from_data_file(\n",
        "    processor=processor,\n",
        "    data_dir=None,  # It is `None` because the data comes from TFDS, not a local dir.\n",
        "    tokenizer=tokenizer,\n",
        "    train_data_output_path=train_data_output_path,\n",
        "    eval_data_output_path=eval_data_output_path,\n",
        "    max_seq_length=128)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "dbJ76vSJj77j"
      },
      "source": [
        "### Create tf.data datasets for training and evaluation\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gCvaLLAxPuMc"
      },
      "outputs": [],
      "source": [
        "def create_classifier_dataset(file_path, seq_length, batch_size, is_training):\n",
        "  \"\"\"Creates input dataset from (tf)records files for train/eval.\"\"\"\n",
        "  dataset = tf.data.TFRecordDataset(file_path)\n",
        "  if is_training:\n",
        "    dataset = dataset.shuffle(100)\n",
        "    dataset = dataset.repeat()\n",
        "\n",
        "  def decode_record(record):\n",
        "    name_to_features = {\n",
        "      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
        "      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
        "      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n",
        "      'label_ids': tf.io.FixedLenFeature([], tf.int64),\n",
        "    }\n",
        "    return tf.io.parse_single_example(record, name_to_features)\n",
        "\n",
        "  def _select_data_from_record(record):\n",
        "    x = {\n",
        "        'input_word_ids': record['input_ids'],\n",
        "        'input_mask': record['input_mask'],\n",
        "        'input_type_ids': record['segment_ids']\n",
        "    }\n",
        "    y = record['label_ids']\n",
        "    return (x, y)\n",
        "\n",
        "  dataset = dataset.map(decode_record,\n",
        "                        num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "  dataset = dataset.map(\n",
        "      _select_data_from_record,\n",
        "      num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "  dataset = dataset.batch(batch_size, drop_remainder=is_training)\n",
        "  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
        "  return dataset\n",
        "\n",
        "# Set up batch sizes\n",
        "batch_size = 32\n",
        "eval_batch_size = 32\n",
        "\n",
        "# Create the training and evaluation datasets.\n",
        "training_dataset = create_classifier_dataset(\n",
        "    train_data_output_path,\n",
        "    input_meta_data['max_seq_length'],\n",
        "    batch_size,\n",
        "    is_training=True)\n",
        "\n",
        "evaluation_dataset = create_classifier_dataset(\n",
        "    eval_data_output_path,\n",
        "    input_meta_data['max_seq_length'],\n",
        "    eval_batch_size,\n",
        "    is_training=False)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Efrj3Cn1kLAp"
      },
      "source": [
        "## Create, compile and train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "96ldxDSwkVkj"
      },
      "source": [
        "### Construct a BERT model\n",
        "\n",
        "Here, a BERT model is constructed from a JSON file of parameters. The `bert_config` defines the core BERT model, a Keras model that predicts outputs over *num_classes* classes from inputs with maximum sequence length *max_seq_length*."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Qgajw8WPYzJZ"
      },
      "outputs": [],
      "source": [
        "bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n",
        "bert_config = bert_configs.BertConfig.from_json_file(bert_config_file)\n",
        "\n",
        "bert_encoder = networks.TransformerEncoder(vocab_size=bert_config.vocab_size,\n",
        "      hidden_size=bert_config.hidden_size,\n",
        "      num_layers=bert_config.num_hidden_layers,\n",
        "      num_attention_heads=bert_config.num_attention_heads,\n",
        "      intermediate_size=bert_config.intermediate_size,\n",
        "      activation=tf_utils.get_activation(bert_config.hidden_act),\n",
        "      dropout_rate=bert_config.hidden_dropout_prob,\n",
        "      attention_dropout_rate=bert_config.attention_probs_dropout_prob,\n",
        "      sequence_length=input_meta_data['max_seq_length'],\n",
        "      max_sequence_length=bert_config.max_position_embeddings,\n",
        "      type_vocab_size=bert_config.type_vocab_size,\n",
        "      embedding_width=bert_config.embedding_size,\n",
        "      initializer=tf.keras.initializers.TruncatedNormal(\n",
        "          stddev=bert_config.initializer_range))\n",
        "\n",
        "classifier_model = models.BertClassifier(\n",
        "        bert_encoder,\n",
        "        num_classes=input_meta_data['num_labels'],\n",
        "        dropout_rate=bert_config.hidden_dropout_prob,\n",
        "        initializer=tf.keras.initializers.TruncatedNormal(\n",
        "          stddev=bert_config.initializer_range))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "pkSq1wbNXBaa"
      },
      "source": [
        "### Initialize the encoder from a pretrained model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "X6N9NEqfXJCx"
      },
      "outputs": [],
      "source": [
        "checkpoint = tf.train.Checkpoint(model=bert_encoder)\n",
        "checkpoint.restore(\n",
        "    os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "115caFLMk-_l"
      },
      "source": [
        "### Set up an optimizer for the model\n",
        "\n",
        "The BERT model adopts the Adam optimizer with weight decay.\n",
        "It also employs a learning rate schedule that first warms up from 0 and then decays to 0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2Hf2rpRXk89N"
      },
      "outputs": [],
      "source": [
        "# Set up epochs and steps\n",
        "epochs = 3\n",
        "train_data_size = input_meta_data['train_data_size']\n",
        "steps_per_epoch = int(train_data_size / batch_size)\n",
        "num_train_steps = steps_per_epoch * epochs\n",
        "warmup_steps = int(epochs * train_data_size * 0.1 / batch_size)\n",
        "\n",
        "# Create a learning rate schedule that first warms up from 0 and then decays to 0.\n",
        "lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
        "      initial_learning_rate=2e-5,\n",
        "      decay_steps=num_train_steps,\n",
        "      end_learning_rate=0)\n",
        "lr_schedule = optimization.WarmUp(\n",
        "        initial_learning_rate=2e-5,\n",
        "        decay_schedule_fn=lr_schedule,\n",
        "        warmup_steps=warmup_steps)\n",
        "optimizer = optimization.AdamWeightDecay(\n",
        "        learning_rate=lr_schedule,\n",
        "        weight_decay_rate=0.01,\n",
        "        beta_1=0.9,\n",
        "        beta_2=0.999,\n",
        "        epsilon=1e-6,\n",
        "        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OTNcA0O0nSq9"
      },
      "source": [
        "### Define metric_fn and loss_fn\n",
        "\n",
        "The metric is accuracy, and we use sparse categorical cross-entropy as the loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ELHjRp87nVNH"
      },
      "outputs": [],
      "source": [
        "def metric_fn():\n",
        "  return tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      'accuracy', dtype=tf.float32)\n",
        "\n",
        "def classification_loss_fn(labels, logits):\n",
        "  return losses.weighted_sparse_categorical_crossentropy_loss(\n",
        "    labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "78FEUOOEkoP0"
      },
      "source": [
        "### Compile and train the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "nzi8hjeTQTRs"
      },
      "outputs": [],
      "source": [
        "classifier_model.compile(optimizer=optimizer,\n",
        "                         loss=classification_loss_fn,\n",
        "                         metrics=[metric_fn()])\n",
        "classifier_model.fit(\n",
        "      x=training_dataset,\n",
        "      validation_data=evaluation_dataset,\n",
        "      steps_per_epoch=steps_per_epoch,\n",
        "      epochs=epochs,\n",
        "      validation_steps=int(input_meta_data['eval_data_size'] / eval_batch_size))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "fVo_AnT0l26j"
      },
      "source": [
        "### Save the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Nl5x6nElZqkP"
      },
      "outputs": [],
      "source": [
        "classifier_model.save('./saved_model', include_optimizer=False, save_format='tf')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "nWsE6yeyfW00"
      },
      "source": [
        "## Use the trained model to make predictions\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vz7YJY2QYAjP"
      },
      "outputs": [],
      "source": [
        "eval_predictions = classifier_model.predict(evaluation_dataset)\n",
        "for prediction in eval_predictions:\n",
        "  print(\"Predicted label id: %s\" % np.argmax(prediction))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "How-to Guide: Using a PIP package for fine-tuning a BERT model",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}