" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
"This tutorial demonstrates how to fine-tune a [Bidirectional Encoder Representations from Transformers (BERT)](https://arxiv.org/abs/1810.04805) (Devlin et al., 2018) model using [TensorFlow Model Garden](https://github.com/tensorflow/models).\n",
"\n",
"You can also find the pre-trained BERT model used in this tutorial on [TensorFlow Hub (TF Hub)](https://tensorflow.org/hub). For concrete examples of how to use the models from TF Hub, refer to the [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial. If you're just trying to fine-tune a model, the TF Hub tutorial is a good starting point.\n",
"\n",
"On the other hand, if you're interested in deeper customization, follow this tutorial. It shows how to do a lot of things manually, so you can learn how you can customize the workflow from data preprocessing to training, exporting and saving the model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s2d9S2CSSO1z"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "69de3375e32a"
},
"source": [
"### Install pip packages"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsACVQpVSifi"
},
"source": [
"Start by installing the TensorFlow Text and Model Garden pip packages.\n",
"\n",
"* `tf-models-official` is the TensorFlow Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` GitHub repo. To include the latest changes, you may install `tf-models-nightly`, which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sE6XUxLOf1s-"
},
"outputs": [],
"source": [
"!pip uninstall -y opencv-python"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yic2y7_o-BCC"
},
"outputs": [],
"source": [
"!pip install -q -U \"tensorflow-text==2.9.*\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NvNr2svBM-p3"
},
"outputs": [],
"source": [
"!pip install -q tf-models-official"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U-7qPCjWUAyy"
},
"source": [
"### Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lXsXev5MNr20"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_models as tfm\n",
"import tensorflow_hub as hub\n",
"import tensorflow_datasets as tfds\n",
"tfds.disable_progress_bar()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mbanlzTvJBsz"
},
"source": [
"### Resources"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PpW0x8TpR8DT"
},
"source": [
"The following directory contains the BERT model's configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:"
"This example uses the GLUE (General Language Understanding Evaluation) MRPC (Microsoft Research Paraphrase Corpus) [dataset from TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).\n",
"\n",
"This dataset is not set up such that it can be directly fed into the BERT model. The following section handles the necessary preprocessing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "28DvUhC1YUiB"
},
"source": [
"### Get the dataset from TensorFlow Datasets\n",
"\n",
"The GLUE MRPC (Dolan and Brockett, 2005) dataset is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent. It has the following attributes:\n",
"\n",
"* Number of labels: 2\n",
"* Size of training dataset: 3668\n",
"* Size of evaluation dataset: 408\n",
"* Maximum sequence length of training and evaluation dataset: 128\n",
"\n",
"Begin by loading the MRPC dataset from TFDS:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ijikx5OsH9AT"
},
"outputs": [],
"source": [
"batch_size=32\n",
"glue, info = tfds.load('glue/mrpc',\n",
" with_info=True,\n",
" batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QcMTJU4N7VX-"
},
"outputs": [],
"source": [
"glue"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZgBg2r2nYT-K"
},
"source": [
"The `info` object describes the dataset and its features:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IQrHxv7W7jH5"
},
"outputs": [],
"source": [
"info.features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vhsVWYNxazz5"
},
"source": [
"The two classes are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n0gfc_VTayfQ"
},
"outputs": [],
"source": [
"info.features['label'].names"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "38zJcap6xkbC"
},
"source": [
"Here is one example from the training set:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xON_i6SkwApW"
},
"outputs": [],
"source": [
"example_batch = next(iter(glue['train']))\n",
"\n",
"for key, value in example_batch.items():\n",
" print(f\"{key:9s}: {value[0].numpy()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R9vEWgKA4SxV"
},
"source": [
"### Preprocess the data\n",
"\n",
"The keys `\"sentence1\"` and `\"sentence2\"` in the GLUE MRPC dataset contain two input sentences for each example.\n",
"\n",
"Because the BERT model from the Model Garden doesn't take raw text as input, two things need to happen first:\n",
"\n",
"1. The text needs to be _tokenized_ (split into word pieces) and converted to _indices_.\n",
"2. Then, the _indices_ need to be packed into the format that the model expects."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9fbTyfJpNr7x"
},
"source": [
"#### The BERT tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wqeN54S61ZKQ"
},
"source": [
"To fine tune a pre-trained language model from the Model Garden, such as BERT, you need to make sure that you're using exactly the same tokenization, vocabulary, and index mapping as used during training.\n",
"\n",
"The following code rebuilds the tokenizer that was used by the base model using the Model Garden's `tfm.nlp.layers.FastWordpieceBertTokenizer` layer:"
"Learn more about the tokenization process in the [Subword tokenization](https://www.tensorflow.org/text/guide/subwords_tokenizer) and [Tokenizing with TensorFlow Text](https://www.tensorflow.org/text/guide/tokenizers) guides."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wd1b09OO5GJl"
},
"source": [
"#### Pack the inputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "62UTWLQd9-LB"
},
"source": [
"TensorFlow Model Garden's BERT model doesn't just take the tokenized strings as input. It also expects these to be packed into a particular format. `tfm.nlp.layers.BertPackInputs` layer can handle the conversion from _a list of tokenized sentences_ to the input format expected by the Model Garden's BERT model.\n",
"\n",
"`tfm.nlp.layers.BertPackInputs` packs the two input sentences (per example in the MRCP dataset) concatenated together. This input is expected to start with a `[CLS]` \"This is a classification problem\" token, and each sentence should end with a `[SEP]` \"Separator\" token.\n",
"\n",
"Therefore, the `tfm.nlp.layers.BertPackInputs` layer's constructor takes the `tokenizer`'s special tokens as an argument. It also needs to know the indices of the tokenizer's special tokens."
"The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zB7mW7DGK3rW"
},
"outputs": [],
"source": [
"plt.pcolormesh(example_inputs['input_mask'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rxLenwAvCkBf"
},
"source": [
"The \"input type\" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of."
"The configuration file defines the core BERT model from the Model Garden, which is a Keras model that predicts the outputs of `num_classes` from the inputs with maximum sequence length `max_seq_length`."
"Note: The pretrained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). Go to the [TF Hub appendix](#hub_bert) for details."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "115caFLMk-_l"
},
"source": [
"### Set up the optimizer\n",
"\n",
"BERT typically uses the Adam optimizer with weight decay—[AdamW](https://arxiv.org/abs/1711.05101) (`tf.keras.optimizers.experimental.AdamW`).\n",
"It also employs a learning rate schedule that first warms up from 0 and then decays to 0:"
"Often the goal of training a model is to _use_ it for something outside of the Python process that created it. You can do this by exporting the model using `tf.saved_model`. (Learn more in the [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide and the [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load) tutorial.)\n",
"\n",
"First, build a wrapper class to export the model. This wrapper does two things:\n",
"\n",
"- First it packages `bert_inputs_processor` and `bert_classifier` together into a single `tf.Module`, so you can export all the functionalities.\n",
"- Second it defines a `tf.function` that implements the end-to-end execution of the model.\n",
"\n",
"Setting the `input_signature` argument of `tf.function` lets you define a fixed signature for the `tf.function`. This can be less surprising than the default automatic retracing behavior."
"Congratulations! You've used `tensorflow_models` to build a BERT-classifier, train it, and export for later use."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eQceYqRFT_Eg"
},
"source": [
"## Optional: BERT on TF Hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QbklKt-w_CiI"
},
"source": [
"\u003ca id=\"hub_bert\"\u003e\u003c/a\u003e\n",
"\n",
"\n",
"You can get the BERT model off the shelf from [TF Hub](https://tfhub.dev/). There are [many versions available along with their input preprocessors](https://tfhub.dev/google/collections/bert/1).\n",
"\n",
"This example uses [a small version of BERT from TF Hub](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2) that was pre-trained using the English Wikipedia and BooksCorpus datasets, similar to the [original implementation](https://arxiv.org/abs/1908.08962) (Turc et al., 2019).\n",
"\n",
"Start by importing TF Hub:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GDWrHm0BGpbX"
},
"outputs": [],
"source": [
"import tensorflow_hub as hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f02f38f83ac4"
},
"source": [
"Select the input preprocessor and the model from TF Hub and wrap them as `hub.KerasLayer` layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"# Always make sure you use the right preprocessor.\n",
"The one downside to loading this model from TF Hub is that the structure of internal Keras layers is not restored. This makes it more difficult to inspect or modify the model.\n",
"\n",
"The BERT encoder model—`hub_classifier`—is now a single layer:"
"For concrete examples of this approach, refer to [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji3tdLz101km"
},
"source": [
"## Optional: Optimizer `config`s\n",
"\n",
"The `tensorflow_models` package defines serializable `config` classes that describe how to build the live objects. Earlier in this tutorial, you built the optimizer manually.\n",
"\n",
"The configuration below describes an (almost) identical optimizer built by the `optimizer_factory.OptimizerFactory`:"
"The advantage to using `config` objects is that they don't contain any complicated TensorFlow objects, and can be easily serialized to JSON, and rebuilt. Here's the JSON for the above `tfm.optimization.OptimizationConfig`:"