Commit 2412b118 authored by Gunho Park's avatar Gunho Park
Browse files

Merge branch 'master' of https://github.com/tensorflow/models

parents f7783e7a 6dbdb08c
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vXLA5InzXydn"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "RuRlpLL-X0R_"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1mLJmVotXs64"
},
"source": [
"# Fine-tuning a BERT model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hYEwGTeCXnnX"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://tfhub.dev/google/collections/bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YN2ACivEPxgD"
},
"source": [
"This tutorial demonstrates how to fine-tune a [Bidirectional Encoder Representations from Transformers (BERT)](https://arxiv.org/abs/1810.04805) (Devlin et al., 2018) model using [TensorFlow Model Garden](https://github.com/tensorflow/models).\n",
"\n",
"You can also find the pre-trained BERT model used in this tutorial on [TensorFlow Hub (TF Hub)](https://tensorflow.org/hub). For concrete examples of how to use the models from TF Hub, refer to the [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial. If you're just trying to fine-tune a model, the TF Hub tutorial is a good starting point.\n",
"\n",
"On the other hand, if you're interested in deeper customization, follow this tutorial. It shows how to do a lot of things manually, so you can learn how you can customize the workflow from data preprocessing to training, exporting and saving the model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s2d9S2CSSO1z"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "69de3375e32a"
},
"source": [
"### Install pip packages"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsACVQpVSifi"
},
"source": [
"Start by installing the TensorFlow Text and Model Garden pip packages.\n",
"\n",
"* `tf-models-official` is the TensorFlow Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` GitHub repo. To include the latest changes, you may install `tf-models-nightly`, which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sE6XUxLOf1s-"
},
"outputs": [],
"source": [
"!pip uninstall -y opencv-python"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yic2y7_o-BCC"
},
"outputs": [],
"source": [
"!pip install -q -U \"tensorflow-text==2.9.*\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NvNr2svBM-p3"
},
"outputs": [],
"source": [
"!pip install -q tf-models-official"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U-7qPCjWUAyy"
},
"source": [
"### Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lXsXev5MNr20"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_models as tfm\n",
"import tensorflow_hub as hub\n",
"import tensorflow_datasets as tfds\n",
"tfds.disable_progress_bar()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mbanlzTvJBsz"
},
"source": [
"### Resources"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PpW0x8TpR8DT"
},
"source": [
"The following directory contains the BERT model's configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vzRHOLciR8eq"
},
"outputs": [],
"source": [
"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12\"\n",
"tf.io.gfile.listdir(gs_folder_bert)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qv6abtRvH4xO"
},
"source": [
"## Load and preprocess the dataset\n",
"\n",
"This example uses the GLUE (General Language Understanding Evaluation) MRPC (Microsoft Research Paraphrase Corpus) [dataset from TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).\n",
"\n",
"This dataset is not set up such that it can be directly fed into the BERT model. The following section handles the necessary preprocessing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "28DvUhC1YUiB"
},
"source": [
"### Get the dataset from TensorFlow Datasets\n",
"\n",
"The GLUE MRPC (Dolan and Brockett, 2005) dataset is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent. It has the following attributes:\n",
"\n",
"* Number of labels: 2\n",
"* Size of training dataset: 3668\n",
"* Size of evaluation dataset: 408\n",
"* Maximum sequence length of training and evaluation dataset: 128\n",
"\n",
"Begin by loading the MRPC dataset from TFDS:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ijikx5OsH9AT"
},
"outputs": [],
"source": [
"batch_size=32\n",
"glue, info = tfds.load('glue/mrpc',\n",
" with_info=True,\n",
" batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QcMTJU4N7VX-"
},
"outputs": [],
"source": [
"glue"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZgBg2r2nYT-K"
},
"source": [
"The `info` object describes the dataset and its features:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IQrHxv7W7jH5"
},
"outputs": [],
"source": [
"info.features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vhsVWYNxazz5"
},
"source": [
"The two classes are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n0gfc_VTayfQ"
},
"outputs": [],
"source": [
"info.features['label'].names"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "38zJcap6xkbC"
},
"source": [
"Here is one example from the training set:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xON_i6SkwApW"
},
"outputs": [],
"source": [
"example_batch = next(iter(glue['train']))\n",
"\n",
"for key, value in example_batch.items():\n",
" print(f\"{key:9s}: {value[0].numpy()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R9vEWgKA4SxV"
},
"source": [
"### Preprocess the data\n",
"\n",
"The keys `\"sentence1\"` and `\"sentence2\"` in the GLUE MRPC dataset contain two input sentences for each example.\n",
"\n",
"Because the BERT model from the Model Garden doesn't take raw text as input, two things need to happen first:\n",
"\n",
"1. The text needs to be _tokenized_ (split into word pieces) and converted to _indices_.\n",
"2. Then, the _indices_ need to be packed into the format that the model expects."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9fbTyfJpNr7x"
},
"source": [
"#### The BERT tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wqeN54S61ZKQ"
},
"source": [
"To fine tune a pre-trained language model from the Model Garden, such as BERT, you need to make sure that you're using exactly the same tokenization, vocabulary, and index mapping as used during training.\n",
"\n",
"The following code rebuilds the tokenizer that was used by the base model using the Model Garden's `tfm.nlp.layers.FastWordpieceBertTokenizer` layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-DK4q5wEBmlB"
},
"outputs": [],
"source": [
"tokenizer = tfm.nlp.layers.FastWordpieceBertTokenizer(\n",
" vocab_file=os.path.join(gs_folder_bert, \"vocab.txt\"),\n",
" lower_case=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zYHDSquU2lDU"
},
"source": [
"Let's tokenize a test sentence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L_OfOYPg853R"
},
"outputs": [],
"source": [
"tokens = tokenizer(tf.constant([\"Hello TensorFlow!\"]))\n",
"tokens"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfjaaMYy5Gt8"
},
"source": [
"Learn more about the tokenization process in the [Subword tokenization](https://www.tensorflow.org/text/guide/subwords_tokenizer) and [Tokenizing with TensorFlow Text](https://www.tensorflow.org/text/guide/tokenizers) guides."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wd1b09OO5GJl"
},
"source": [
"#### Pack the inputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "62UTWLQd9-LB"
},
"source": [
"TensorFlow Model Garden's BERT model doesn't just take the tokenized strings as input. It also expects these to be packed into a particular format. `tfm.nlp.layers.BertPackInputs` layer can handle the conversion from _a list of tokenized sentences_ to the input format expected by the Model Garden's BERT model.\n",
"\n",
"`tfm.nlp.layers.BertPackInputs` packs the two input sentences (per example in the MRCP dataset) concatenated together. This input is expected to start with a `[CLS]` \"This is a classification problem\" token, and each sentence should end with a `[SEP]` \"Separator\" token.\n",
"\n",
"Therefore, the `tfm.nlp.layers.BertPackInputs` layer's constructor takes the `tokenizer`'s special tokens as an argument. It also needs to know the indices of the tokenizer's special tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5iroDlrFDRcF"
},
"outputs": [],
"source": [
"special = tokenizer.get_special_tokens_dict()\n",
"special"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b71HarkuG92H"
},
"outputs": [],
"source": [
"max_seq_length = 128\n",
"\n",
"packer = tfm.nlp.layers.BertPackInputs(\n",
" seq_length=max_seq_length,\n",
" special_tokens_dict = tokenizer.get_special_tokens_dict())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CZlSZbYd6liN"
},
"source": [
"The `packer` takes a list of tokenized sentences as input. For example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "27dU_VkJHc9S"
},
"outputs": [],
"source": [
"sentences1 = [\"hello tensorflow\"]\n",
"tok1 = tokenizer(sentences1)\n",
"tok1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LURHmNOSHnWN"
},
"outputs": [],
"source": [
"sentences2 = [\"goodbye tensorflow\"]\n",
"tok2 = tokenizer(sentences2)\n",
"tok2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r8bvB8gI8BqP"
},
"source": [
"Then, it returns a dictionary containing three outputs:\n",
"\n",
"- `input_word_ids`: The tokenized sentences packed together.\n",
"- `input_mask`: The mask indicating which locations are valid in the other outputs.\n",
"- `input_type_ids`: Indicating which sentence each token belongs to."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YsIDTOMJHrUQ"
},
"outputs": [],
"source": [
"packed = packer([tok1, tok2])\n",
"\n",
"for key, tensor in packed.items():\n",
" print(f\"{key:15s}: {tensor[:, :12]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "red4tRcq74Qc"
},
"source": [
"#### Put it all together\n",
"\n",
"Combine these two parts into a `keras.layers.Layer` that can be attached to your model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Qtz-tv-6nz6"
},
"outputs": [],
"source": [
"class BertInputProcessor(tf.keras.layers.Layer):\n",
" def __init__(self, tokenizer, packer):\n",
" super().__init__()\n",
" self.tokenizer = tokenizer\n",
" self.packer = packer\n",
"\n",
" def call(self, inputs):\n",
" tok1 = self.tokenizer(inputs['sentence1'])\n",
" tok2 = self.tokenizer(inputs['sentence2'])\n",
"\n",
" packed = self.packer([tok1, tok2])\n",
"\n",
" if 'label' in inputs:\n",
" return packed, inputs['label']\n",
" else:\n",
" return packed"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rdy9wp499btU"
},
"source": [
"But for now just apply it to the dataset using `Dataset.map`, since the dataset you loaded from TFDS is a `tf.data.Dataset` object:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qmyh76AL7VAs"
},
"outputs": [],
"source": [
"bert_inputs_processor = BertInputProcessor(tokenizer, packer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B8SSCtDe9MCk"
},
"outputs": [],
"source": [
"glue_train = glue['train'].map(bert_inputs_processor).prefetch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXpiDosO9rkY"
},
"source": [
"Here is an example batch from the processed dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ffNvDE6t9rP-"
},
"outputs": [],
"source": [
"example_inputs, example_labels = next(iter(glue_train))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5sxtTuUi-bXt"
},
"outputs": [],
"source": [
"example_inputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wP4z_-9a-dFk"
},
"outputs": [],
"source": [
"example_labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jyjTdGpFhO_1"
},
"outputs": [],
"source": [
"for key, value in example_inputs.items():\n",
" print(f'{key:15s} shape: {value.shape}')\n",
"\n",
"print(f'{\"labels\":15s} shape: {example_labels.shape}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mkGHN_FK-50U"
},
"source": [
"The `input_word_ids` contain the token IDs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eGL1_ktWLcgF"
},
"outputs": [],
"source": [
"plt.pcolormesh(example_inputs['input_word_ids'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ulNZ4U96-8JZ"
},
"source": [
"The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zB7mW7DGK3rW"
},
"outputs": [],
"source": [
"plt.pcolormesh(example_inputs['input_mask'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rxLenwAvCkBf"
},
"source": [
"The \"input type\" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2CetH_5C9P2m"
},
"outputs": [],
"source": [
"plt.pcolormesh(example_inputs['input_type_ids'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pxHHeyei_sb9"
},
"source": [
"Apply the same preprocessing to the validation and test subsets of the GLUE MRPC dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yuLKxf6zHxw-"
},
"outputs": [],
"source": [
"glue_validation = glue['validation'].map(bert_inputs_processor).prefetch(1)\n",
"glue_test = glue['test'].map(bert_inputs_processor).prefetch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FSwymsbkbLDA"
},
"source": [
"## Build, train and export the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bxxO3pJCEM9p"
},
"source": [
"Now that you have formatted the data as expected, you can start working on building and training the model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Efrj3Cn1kLAp"
},
"source": [
"### Build the model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xxpOY5r2Ayq6"
},
"source": [
"The first step is to download the configuration file—`config_dict`—for the pre-trained BERT model:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v7ap0BONSJuz"
},
"outputs": [],
"source": [
"import json\n",
"\n",
"bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n",
"config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
"config_dict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pKaEaKJSX85J"
},
"outputs": [],
"source": [
"encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
" 'type':'bert',\n",
" 'bert': config_dict\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LbgzWukNSqOS"
},
"outputs": [],
"source": [
"bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
"bert_encoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "96ldxDSwkVkj"
},
"source": [
"The configuration file defines the core BERT model from the Model Garden, which is a Keras model that predicts the outputs of `num_classes` from the inputs with maximum sequence length `max_seq_length`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cH682__U0FBv"
},
"outputs": [],
"source": [
"bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XqKp3-5GIZlw"
},
"source": [
"The classifier has three inputs and one output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bAQblMIjwkvx"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sFmVG4SKZAw8"
},
"source": [
"Run it on a test batch of data 10 examples from the training set. The output is the logits for the two classes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VTjgPbp4ZDKo"
},
"outputs": [],
"source": [
"bert_classifier(\n",
" example_inputs, training=True).numpy()[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q0NTdwZsQK8n"
},
"source": [
"The `TransformerEncoder` in the center of the classifier above **is** the `bert_encoder`.\n",
"\n",
"If you inspect the encoder, notice the stack of `Transformer` layers connected to those same three inputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8L__-erBwLIQ"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mKAvkQc3heSy"
},
"source": [
"### Restore the encoder weights\n",
"\n",
"When built, the encoder is randomly initialized. Restore the encoder's weights from the checkpoint:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "97Ll2Gichd_Y"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
"checkpoint.read(\n",
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2oHOql35k3Dd"
},
"source": [
"Note: The pretrained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). Go to the [TF Hub appendix](#hub_bert) for details."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "115caFLMk-_l"
},
"source": [
"### Set up the optimizer\n",
"\n",
"BERT typically uses the Adam optimizer with weight decay—[AdamW](https://arxiv.org/abs/1711.05101) (`tf.keras.optimizers.experimental.AdamW`).\n",
"It also employs a learning rate schedule that first warms up from 0 and then decays to 0:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c0jBycPDtkxR"
},
"outputs": [],
"source": [
"# Set up epochs and steps\n",
"epochs = 5\n",
"batch_size = 32\n",
"eval_batch_size = 32\n",
"\n",
"train_data_size = info.splits['train'].num_examples\n",
"steps_per_epoch = int(train_data_size / batch_size)\n",
"num_train_steps = steps_per_epoch * epochs\n",
"warmup_steps = int(0.1 * num_train_steps)\n",
"initial_learning_rate=2e-5"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GFankgHK0Rvh"
},
"source": [
"Linear decay from `initial_learning_rate` to zero over `num_train_steps`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qWSyT8P2j4mV"
},
"outputs": [],
"source": [
"linear_decay = tf.keras.optimizers.schedules.PolynomialDecay(\n",
" initial_learning_rate=initial_learning_rate,\n",
" end_learning_rate=0,\n",
" decay_steps=num_train_steps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "anZPZPAP0Y3n"
},
"source": [
"Warmup to that value over `warmup_steps`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z_AsVCiRkoN1"
},
"outputs": [],
"source": [
"warmup_schedule = tfm.optimization.lr_schedule.LinearWarmup(\n",
" warmup_learning_rate = 0,\n",
" after_warmup_lr_sched = linear_decay,\n",
" warmup_steps = warmup_steps\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arfbaK6t0kH_"
},
"source": [
"The overall schedule looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rYZGunhqbGUZ"
},
"outputs": [],
"source": [
"x = tf.linspace(0, num_train_steps, 1001)\n",
"y = [warmup_schedule(xi) for xi in x]\n",
"plt.plot(x,y)\n",
"plt.xlabel('Train step')\n",
"plt.ylabel('Learning rate')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bjsmG_fm0opn"
},
"source": [
"Use `tf.keras.optimizers.experimental.AdamW` to instantiate the optimizer with that schedule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R8pTNuKIw1dA"
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.experimental.Adam(\n",
" learning_rate = warmup_schedule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "78FEUOOEkoP0"
},
"source": [
"### Train the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OTNcA0O0nSq9"
},
"source": [
"Set the metric as accuracy and the loss as sparse categorical cross-entropy. Then, compile and train the BERT classifier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d5FeL0b6j7ky"
},
"outputs": [],
"source": [
"metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"\n",
"bert_classifier.compile(\n",
" optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsrylctIj_Xy"
},
"outputs": [],
"source": [
"bert_classifier.evaluate(glue_validation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hgPPc2oNmcVZ"
},
"outputs": [],
"source": [
"bert_classifier.fit(\n",
" glue_train,\n",
" validation_data=(glue_validation),\n",
" batch_size=32,\n",
" epochs=epochs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IFtKFWbNKb0u"
},
"source": [
"Now run the fine-tuned model on a custom example to see that it works.\n",
"\n",
"Start by encoding some sentence pairs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S1sdW6lLWaEi"
},
"outputs": [],
"source": [
"my_examples = {\n",
" 'sentence1':[\n",
" 'The rain in Spain falls mainly on the plain.',\n",
" 'Look I fine tuned BERT.'],\n",
" 'sentence2':[\n",
" 'It mostly rains on the flat lands of Spain.',\n",
" 'Is it working? This does not match.']\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7ynJibkBRTJF"
},
"source": [
"The model should report class `1` \"match\" for the first example and class `0` \"no-match\" for the second:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "umo0ttrgRYIM"
},
"outputs": [],
"source": [
"ex_packed = bert_inputs_processor(my_examples)\n",
"my_logits = bert_classifier(ex_packed, training=False)\n",
"\n",
"result_cls_ids = tf.argmax(my_logits)\n",
"result_cls_ids"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HNdmOEHKT7e8"
},
"outputs": [],
"source": [
"tf.gather(tf.constant(info.features['label'].names), result_cls_ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVo_AnT0l26j"
},
"source": [
"### Export the model\n",
"\n",
"Often the goal of training a model is to _use_ it for something outside of the Python process that created it. You can do this by exporting the model using `tf.saved_model`. (Learn more in the [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide and the [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load) tutorial.)\n",
"\n",
"First, build a wrapper class to export the model. This wrapper does two things:\n",
"\n",
"- First it packages `bert_inputs_processor` and `bert_classifier` together into a single `tf.Module`, so you can export all the functionalities.\n",
"- Second it defines a `tf.function` that implements the end-to-end execution of the model.\n",
"\n",
"Setting the `input_signature` argument of `tf.function` lets you define a fixed signature for the `tf.function`. This can be less surprising than the default automatic retracing behavior."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "78h83mlt9wpY"
},
"outputs": [],
"source": [
"class ExportModel(tf.Module):\n",
" def __init__(self, input_processor, classifier):\n",
" self.input_processor = input_processor\n",
" self.classifier = classifier\n",
"\n",
" @tf.function(input_signature=[{\n",
" 'sentence1': tf.TensorSpec(shape=[None], dtype=tf.string),\n",
" 'sentence2': tf.TensorSpec(shape=[None], dtype=tf.string)}])\n",
" def __call__(self, inputs):\n",
" packed = self.input_processor(inputs)\n",
" logits = self.classifier(packed, training=False)\n",
" result_cls_ids = tf.argmax(logits)\n",
" return {\n",
" 'logits': logits,\n",
" 'class_id': result_cls_ids,\n",
" 'class': tf.gather(\n",
" tf.constant(info.features['label'].names),\n",
" result_cls_ids)\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qnxysGUfIgFQ"
},
"source": [
"Create an instance of this export-model and save it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TmHW9DEFUZ0X"
},
"outputs": [],
"source": [
"export_model = ExportModel(bert_inputs_processor, bert_classifier)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Nl5x6nElZqkP"
},
"outputs": [],
"source": [
"import tempfile\n",
"export_dir=tempfile.mkdtemp(suffix='_saved_model')\n",
"tf.saved_model.save(export_model, export_dir=export_dir,\n",
" signatures={'serving_default': export_model.__call__})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pd8B5dy-ImDJ"
},
"source": [
"Reload the model and compare the results to the original:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9cAhHySVXHD5"
},
"outputs": [],
"source": [
"original_logits = export_model(my_examples)['logits']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "H9cAcYwfW2fy"
},
"outputs": [],
"source": [
"reloaded = tf.saved_model.load(export_dir)\n",
"reloaded_logits = reloaded(my_examples)['logits']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y_ACvKPsVUXC"
},
"outputs": [],
"source": [
"# The results are identical:\n",
"print(original_logits.numpy())\n",
"print()\n",
"print(reloaded_logits.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lBlPP20dXPFR"
},
"outputs": [],
"source": [
"print(np.mean(abs(original_logits - reloaded_logits)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CPsg7dZwfBM2"
},
"source": [
"Congratulations! You've used `tensorflow_models` to build a BERT-classifier, train it, and export for later use."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eQceYqRFT_Eg"
},
"source": [
"## Optional: BERT on TF Hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QbklKt-w_CiI"
},
"source": [
"\u003ca id=\"hub_bert\"\u003e\u003c/a\u003e\n",
"\n",
"\n",
"You can get the BERT model off the shelf from [TF Hub](https://tfhub.dev/). There are [many versions available along with their input preprocessors](https://tfhub.dev/google/collections/bert/1).\n",
"\n",
"This example uses [a small version of BERT from TF Hub](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2) that was pre-trained using the English Wikipedia and BooksCorpus datasets, similar to the [original implementation](https://arxiv.org/abs/1908.08962) (Turc et al., 2019).\n",
"\n",
"Start by importing TF Hub:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GDWrHm0BGpbX"
},
"outputs": [],
"source": [
"import tensorflow_hub as hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f02f38f83ac4"
},
"source": [
"Select the input preprocessor and the model from TF Hub and wrap them as `hub.KerasLayer` layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"# Always make sure you use the right preprocessor.\n",
"hub_preprocessor = hub.KerasLayer(\n",
" \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\")\n",
"\n",
"# This is a really small BERT.\n",
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2\",\n",
" trainable=True)\n",
"\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iTzF574wivQv"
},
"source": [
"Test run the preprocessor on a batch of data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GOASSKR5R3-N"
},
"outputs": [],
"source": [
"hub_inputs = hub_preprocessor(['Hello TensorFlow!'])\n",
"{key: value[0, :10].numpy() for key, value in hub_inputs.items()} "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XEcYrCR45Uwo"
},
"outputs": [],
"source": [
"result = hub_encoder(\n",
" inputs=hub_inputs,\n",
" training=False,\n",
")\n",
"\n",
"print(\"Pooled output shape:\", result['pooled_output'].shape)\n",
"print(\"Sequence output shape:\", result['sequence_output'].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cjojn8SmLSRI"
},
"source": [
"At this point it would be simple to add a classification head yourself.\n",
"\n",
"The Model Garden `tfm.nlp.models.BertClassifier` class can also build a classifier onto the TF Hub encoder:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9nTDaApyLR70"
},
"outputs": [],
"source": [
"hub_classifier = tfm.nlp.models.BertClassifier(\n",
" bert_encoder,\n",
" num_classes=2,\n",
" dropout_rate=0.1,\n",
" initializer=tf.keras.initializers.TruncatedNormal(\n",
" stddev=0.02))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xMJX3wV0_v7I"
},
"source": [
"The one downside to loading this model from TF Hub is that the structure of internal Keras layers is not restored. This makes it more difficult to inspect or modify the model.\n",
"\n",
"The BERT encoder model—`hub_classifier`—is now a single layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pD71dnvhM2QS"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(hub_classifier, show_shapes=True, dpi=64)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u_IqwXjRV1vd"
},
"source": [
"For concrete examples of this approach, refer to [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji3tdLz101km"
},
"source": [
"## Optional: Optimizer `config`s\n",
"\n",
"The `tensorflow_models` package defines serializable `config` classes that describe how to build the live objects. Earlier in this tutorial, you built the optimizer manually.\n",
"\n",
"The configuration below describes an (almost) identical optimizer built by the `optimizer_factory.OptimizerFactory`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fdb9C1ontnH_"
},
"outputs": [],
"source": [
"optimization_config = tfm.optimization.OptimizationConfig(\n",
" optimizer=tfm.optimization.OptimizerConfig(\n",
" type = \"adam\"),\n",
" learning_rate = tfm.optimization.LrConfig(\n",
" type='polynomial',\n",
" polynomial=tfm.optimization.PolynomialLrConfig(\n",
" initial_learning_rate=2e-5,\n",
" end_learning_rate=0.0,\n",
" decay_steps=num_train_steps)),\n",
" warmup = tfm.optimization.WarmupConfig(\n",
" type='linear',\n",
" linear=tfm.optimization.LinearWarmupConfig(warmup_steps=warmup_steps)\n",
" ))\n",
"\n",
"\n",
"fac = tfm.optimization.optimizer_factory.OptimizerFactory(optimization_config)\n",
"lr = fac.build_learning_rate()\n",
"optimizer = fac.build_optimizer(lr=lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rp7R1hBfv5HG"
},
"outputs": [],
"source": [
"x = tf.linspace(0, num_train_steps, 1001).numpy()\n",
"y = [lr(xi) for xi in x]\n",
"plt.plot(x,y)\n",
"plt.xlabel('Train step')\n",
"plt.ylabel('Learning rate')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ywn5miD_dnuh"
},
"source": [
"The advantage to using `config` objects is that they don't contain any complicated TensorFlow objects, and can be easily serialized to JSON, and rebuilt. Here's the JSON for the above `tfm.optimization.OptimizationConfig`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zo5RV5lud81Y"
},
"outputs": [],
"source": [
"optimization_config = optimization_config.as_dict()\n",
"optimization_config"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z6qPXPEhekkd"
},
"source": [
"The `tfm.optimization.optimizer_factory.OptimizerFactory` can just as easily build the optimizer from the JSON dictionary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p-bYrvfMYsxp"
},
"outputs": [],
"source": [
"fac = tfm.optimization.optimizer_factory.OptimizerFactory(\n",
" tfm.optimization.OptimizationConfig(optimization_config))\n",
"lr = fac.build_learning_rate()\n",
"optimizer = fac.build_optimizer(lr=lr)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "fine_tune_bert.ipynb",
"private_outputs": true,
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Core is shared by both `nlp` and `vision`."""
from official.core import actions
from official.core import base_task
from official.core import base_trainer
......@@ -21,6 +22,7 @@ from official.core import exp_factory
from official.core import export_base
from official.core import input_reader
from official.core import registry
from official.core import savedmodel_checkpoint_manager
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom checkpoint manager that also exports saved models."""
import os
import re
from typing import Callable, Mapping, Optional
from absl import logging
import tensorflow as tf
def make_saved_modules_directory_name(checkpoint_name: str) -> str:
return f'{checkpoint_name}_saved_modules'
class SavedModelCheckpointManager(tf.train.CheckpointManager):
"""A CheckpointManager that also exports `SavedModel`s."""
def __init__(self,
checkpoint: tf.train.Checkpoint,
directory: str,
max_to_keep: int,
modules_to_export: Optional[Mapping[str, tf.Module]] = None,
keep_checkpoint_every_n_hours: Optional[int] = None,
checkpoint_name: str = 'ckpt',
step_counter: Optional[tf.Variable] = None,
checkpoint_interval: Optional[int] = None,
init_fn: Optional[Callable[[], None]] = None):
"""See base class."""
super().__init__(
checkpoint=checkpoint,
directory=directory,
max_to_keep=max_to_keep,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
checkpoint_name=checkpoint_name,
step_counter=step_counter,
checkpoint_interval=checkpoint_interval,
init_fn=init_fn)
self._modules_to_export = modules_to_export
def save(self,
checkpoint_number=None,
check_interval: bool = True,
options: Optional[tf.train.CheckpointOptions] = None):
"""See base class."""
checkpoint_path = super().save(
checkpoint_number=checkpoint_number,
check_interval=check_interval,
options=options)
if not checkpoint_path: # Nothing got written.
return
if not self._modules_to_export: # No modules to export.
logging.info('Skip saving SavedModel due to empty modules_to_export.')
return checkpoint_path
# Save the models for the checkpoint that just got written.
saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
for model_name, model in self._modules_to_export.items():
tf.saved_model.save(
obj=model,
export_dir=os.path.join(saved_modules_directory, model_name))
# `checkpoint_path` ends in `-[\d]+`. We want to glob for all existing
# checkpoints, and we use the .index file for that.
checkpoint_glob = re.sub(r'\d+$', '*.index', checkpoint_path)
existing_checkpoint_files = tf.io.gfile.glob(checkpoint_glob)
saved_modules_directories_to_keep = [
make_saved_modules_directory_name(os.path.splitext(ckpt_index)[0])
for ckpt_index in existing_checkpoint_files
]
saved_modules_glob = re.sub(r'\d+_saved_modules$', '*_saved_modules',
saved_modules_directory)
for existing_saved_modules_dir in tf.io.gfile.glob(saved_modules_glob):
if (existing_saved_modules_dir not in saved_modules_directories_to_keep
and tf.io.gfile.isdir(existing_saved_modules_dir)):
tf.io.gfile.rmtree(existing_saved_modules_dir)
return checkpoint_path
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Iterable
import tensorflow as tf
from official.core import savedmodel_checkpoint_manager
def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
for model_name in models:
if not tf.io.gfile.isdir(
os.path.join(
savedmodel_checkpoint_manager.make_saved_modules_directory_name(
checkpoint_path), model_name)):
return False
return True
class CheckpointManagerTest(tf.test.TestCase):
def testSimpleTest(self):
models = {
"model_1":
tf.keras.Sequential(
layers=[tf.keras.layers.Dense(8, input_shape=(16,))]),
"model_2":
tf.keras.Sequential(
layers=[tf.keras.layers.Dense(16, input_shape=(32,))]),
}
checkpoint = tf.train.Checkpoint()
manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
checkpoint=checkpoint,
directory=self.get_temp_dir(),
max_to_keep=1,
modules_to_export=models)
first_path = manager.save()
second_path = manager.save()
self.assertTrue(_models_exist(second_path, models.keys()))
self.assertFalse(_models_exist(first_path, models.keys()))
if __name__ == "__main__":
tf.test.main()
......@@ -32,6 +32,226 @@ from official.core import train_utils
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
class OrbitExperimentRunner:
"""Runs experiment with Orbit training loop.
The default experiment runner for model garden experiments. User can
customize the experiment pipeline by subclassing this class and replacing
components or functions.
For example, an experiment runner with customized checkpoint manager:
```python
class MyExpRunnerWithExporter(AbstractExperimentRunner):
def _maybe_build_checkpoint_manager(sefl):
return MyCheckpointManager(*args)
# In user code
MyExpRunnerWithExporter(**needed_kwargs).run(mode)
```
Similar override can be done to other components.
"""
def __init__(
self,
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
):
"""Constructor.
Args:
distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval',
'train_and_eval' or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: the base_trainer.Trainer instance. It should be created within
the strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
"""
self.strategy = distribution_strategy or tf.distribute.get_strategy()
self._params = params
self._model_dir = model_dir
self._mode = mode
self._run_post_eval = run_post_eval
self._trainer = trainer or self._build_trainer(
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval)
assert self.trainer is not None
self._checkpoint_manager = self._maybe_build_checkpoint_manager()
self._controller = self._build_controller(
trainer=self.trainer if 'train' in mode else None,
evaluator=self.trainer,
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions,
controller_cls=controller_cls)
@property
def params(self) -> config_definitions.ExperimentConfig:
return self._params
@property
def model_dir(self) -> str:
return self._model_dir
@property
def trainer(self) -> base_trainer.Trainer:
return self._trainer
@property
def checkpoint_manager(self) -> tf.train.CheckpointManager:
return self._checkpoint_manager
@property
def controller(self) -> orbit.Controller:
return self._controller
def _build_trainer(self, task: base_task.Task, train: bool,
evaluate: bool) -> base_trainer.Trainer:
"""Create trainer."""
with self.strategy.scope():
trainer = train_utils.create_trainer(
self.params,
task,
train=train,
evaluate=evaluate,
checkpoint_exporter=self._build_best_checkpoint_exporter())
return trainer
def _build_best_checkpoint_exporter(self):
return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
def _maybe_build_checkpoint_manager(
self) -> Optional[tf.train.CheckpointManager]:
"""Maybe create a CheckpointManager."""
assert self.trainer is not None
if self.trainer.checkpoint:
if self.model_dir is None:
raise ValueError('model_dir must be specified, but got None')
checkpoint_manager = tf.train.CheckpointManager(
self.trainer.checkpoint,
directory=self.model_dir,
max_to_keep=self.params.trainer.max_to_keep,
step_counter=self.trainer.global_step,
checkpoint_interval=self.params.trainer.checkpoint_interval,
init_fn=self.trainer.initialize)
else:
checkpoint_manager = None
return checkpoint_manager
def _build_controller(self,
trainer,
evaluator,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller) -> orbit.Controller:
"""Builds a Orbit controler."""
train_actions = [] if not train_actions else train_actions
if trainer:
train_actions += actions.get_train_actions(
self.params,
trainer,
self.model_dir,
checkpoint_manager=self.checkpoint_manager)
eval_actions = [] if not eval_actions else eval_actions
if evaluator:
eval_actions += actions.get_eval_actions(self.params, evaluator,
self.model_dir)
controller = controller_cls(
strategy=self.strategy,
trainer=trainer,
evaluator=evaluator,
global_step=self.trainer.global_step,
steps_per_loop=self.params.trainer.steps_per_loop,
checkpoint_manager=self.checkpoint_manager,
summary_dir=os.path.join(self.model_dir, 'train') if
(save_summary) else None,
eval_summary_dir=os.path.join(
self.model_dir, self.params.trainer.validation_summary_subdir) if
(save_summary) else None,
summary_interval=self.params.trainer.summary_interval if
(save_summary) else None,
train_actions=train_actions,
eval_actions=eval_actions)
return controller
def run(self) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Run experiments by mode.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
mode = self._mode
params = self.params
logging.info('Starts to execute mode: %s', mode)
with self.strategy.scope():
if mode == 'train' or mode == 'train_and_post_eval':
self.controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
self.controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
self.controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if self.trainer.global_step.numpy() >= params.trainer.train_steps:
return True
return False
self.controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
num_params = train_utils.try_count_params(self.trainer.model)
if num_params is not None:
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)
flops = train_utils.try_count_flops(self.trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if self._run_post_eval or mode == 'train_and_post_eval':
with self.strategy.scope():
return self.trainer.model, self.controller.evaluate(
steps=params.trainer.validation_steps)
else:
return self.trainer.model, {}
def run_experiment(
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
......@@ -70,91 +290,17 @@ def run_experiment(
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
with distribution_strategy.scope():
if not trainer:
trainer = train_utils.create_trainer(
params,
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))
if trainer.checkpoint:
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=trainer.global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize)
else:
checkpoint_manager = None
train_actions = [] if not train_actions else train_actions
train_actions += actions.get_train_actions(
params, trainer, model_dir, checkpoint_manager=checkpoint_manager)
eval_actions = [] if not eval_actions else eval_actions
eval_actions += actions.get_eval_actions(params, trainer, model_dir)
controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
global_step=trainer.global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
eval_summary_dir=os.path.join(model_dir,
params.trainer.validation_summary_subdir) if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None,
runner = OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval,
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions)
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train' or mode == 'train_and_post_eval':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if trainer.global_step.numpy() >= params.trainer.train_steps:
return True
return False
controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
num_params = train_utils.try_count_params(trainer.model)
if num_params is not None:
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)
flops = train_utils.try_count_flops(trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if run_post_eval or mode == 'train_and_post_eval':
with distribution_strategy.scope():
return trainer.model, controller.evaluate(
steps=params.trainer.validation_steps)
else:
return trainer.model, {}
eval_actions=eval_actions,
trainer=trainer,
controller_cls=controller_cls,
)
return runner.run()
......@@ -117,6 +117,61 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
model_dir=model_dir,
run_post_eval=run_post_eval)
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end_class(self, distribution_strategy, flag_mode,
run_post_eval):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
_, logs = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval).run()
if 'eval' in flag_mode:
self.assertTrue(
tf.io.gfile.exists(
os.path.join(model_dir,
params.trainer.validation_summary_subdir)))
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
if flag_mode == 'eval':
return
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
# Tests continuous evaluation.
_, logs = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode='continuous_eval',
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval).run()
@combinations.generate(
combinations.combine(
distribution_strategy=[
......@@ -148,12 +203,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
task.build_losses = build_losses
with self.assertRaises(RuntimeError):
train_lib.run_experiment(
train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
model_dir=model_dir).run()
@combinations.generate(
combinations.combine(
......@@ -194,12 +249,12 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
task.build_losses = build_losses
model, _ = train_lib.run_experiment(
model, _ = train_lib.OrbitExperimentRunner(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
model_dir=model_dir).run()
after_weights = model.get_weights()
for left, right in zip(before_weights, after_weights):
self.assertAllEqual(left, right)
......
......@@ -19,6 +19,7 @@ import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -57,9 +58,9 @@ class GatedFeedforward(tf.keras.layers.Layer):
"""
def __init__(self,
intermediate_size,
intermediate_activation,
dropout,
inner_dim=768,
inner_activation=tf_utils.get_activation("gelu"),
dropout=0.0,
use_gate=True,
apply_output_layer_norm=True,
num_blocks=1,
......@@ -72,9 +73,12 @@ class GatedFeedforward(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(GatedFeedforward, self).__init__(**kwargs)
self._intermediate_size = intermediate_size
self._intermediate_activation = intermediate_activation
inner_dim = kwargs.pop("intermediate_size", inner_dim)
inner_activation = kwargs.pop("intermediate_activation", inner_activation)
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
self._inner_dim = inner_dim
self._inner_activation = inner_activation
self._dropout = dropout
self._use_gate = use_gate
self._num_blocks = num_blocks
......@@ -103,7 +107,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._intermediate_dense = []
self._intermediate_activation_layers = []
self._inner_activation_layers = []
self._gate_dense = []
self._output_dense = []
self._output_dropout = []
......@@ -118,7 +122,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
self._intermediate_dense.append(
tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
output_shape=(None, self._inner_dim),
bias_axes="d",
name="intermediate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
......@@ -126,14 +130,14 @@ class GatedFeedforward(tf.keras.layers.Layer):
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._intermediate_activation_layers.append(
self._inner_activation_layers.append(
tf.keras.layers.Activation(
self._intermediate_activation, dtype=activation_policy))
self._inner_activation, dtype=activation_policy))
if self._use_gate:
self._gate_dense.append(
tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
output_shape=(None, self._inner_dim),
bias_axes="d",
name="gate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
......@@ -164,10 +168,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
def get_config(self):
config = {
"intermediate_size":
self._intermediate_size,
"intermediate_activation":
self._intermediate_activation,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"dropout":
self._dropout,
"use_gate":
......@@ -191,7 +195,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
}
base_config = super(GatedFeedforward, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
......@@ -199,7 +203,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
for i in range(self._num_blocks):
layer_input = layer_output
intermediate_output = self._intermediate_dense[i](layer_input)
intermediate_output = self._intermediate_activation_layers[i](
intermediate_output = self._inner_activation_layers[i](
intermediate_output)
if self._use_gate:
gated_linear = self._gate_dense[i](layer_input)
......
......@@ -44,8 +44,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
def test_layer_creation(self, use_gate, num_blocks, dropout_position, dtype):
tf.keras.mixed_precision.set_global_policy(dtype)
kwargs = dict(
intermediate_size=128,
intermediate_activation="relu",
inner_dim=128,
inner_activation="relu",
dropout=0.1,
use_gate=use_gate,
num_blocks=num_blocks,
......@@ -76,8 +76,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
dtype):
tf.keras.mixed_precision.set_global_policy(dtype)
kwargs = dict(
intermediate_size=16,
intermediate_activation="relu",
inner_dim=16,
inner_activation="relu",
dropout=0.1,
use_gate=use_gate,
num_blocks=num_blocks,
......@@ -104,8 +104,8 @@ class GatedFeedforwardTest(keras_parameterized.TestCase):
def test_serialize_deserialize(self):
kwargs = dict(
intermediate_size=16,
intermediate_activation="relu",
inner_dim=16,
inner_activation="relu",
dropout=0.1,
use_gate=False,
num_blocks=4,
......
......@@ -76,7 +76,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_dropout_rate)
dropout_rate = kwargs.pop("output_dropout", dropout_rate)
inner_dim = kwargs.pop("intermediate_size", inner_dim)
inner_activation = kwargs.pop("inner_activation", inner_activation)
inner_activation = kwargs.pop("intermediate_activation", inner_activation)
util.filter_kwargs(kwargs)
super().__init__(**kwargs)
......
......@@ -19,8 +19,6 @@ import os
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras.testing_utils import layer_test
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
......@@ -45,13 +43,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((768, 6), (1024, 2))
def test_keras_layer(self, input_dim, proj_multiple):
self.skipTest('Disable the test for now since it imports '
'keras.testing_utils, will reenable this test after we '
'fix the b/184578869')
# TODO(scottzhu): Reenable after fix b/184578869
data = np.random.normal(size=(100, input_dim))
data = data.astype(np.float32)
layer_test(
tf.keras.__internal__.utils.layer_test(
TNExpandCondense,
kwargs={
'proj_multiplier': proj_multiple,
......@@ -64,9 +58,9 @@ class TNLayerTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters((768, 6), (1024, 2))
def test_train(self, input_dim, proj_multiple):
tf.keras.utils.set_random_seed(0)
data = np.random.randint(10, size=(100, input_dim))
model = self._build_model(data, proj_multiple)
tf.keras.utils.set_random_seed(0)
model.compile(
optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
......
......@@ -75,7 +75,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
E.g. let's say input dims are `[batch_size, seq_dim, input_last_dim]`.
Scenario 1: If `output_last_dim` is not `None`, then the output dims of this
module would be `[batch_size, seq_dim, output_last_dim]`. Note `key_dim` is
is overriden by `output_last_dim`.
overriden by `output_last_dim`.
Scenario 2: If `output_last_dim` is `None` and `key_dim` is not `None`, then
the output dims of this module would be `[batch_size, seq_dim, key_dim]`.
Scenario 3: If the `output_last_dim` and `key_dim` are both `None`, the
......
......@@ -124,6 +124,7 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
num_params_per_anchor: int = 4,
share_classification_heads: bool = False,
**kwargs):
"""Initializes a RetinaNet quantized head.
......@@ -156,8 +157,13 @@ class RetinaNetHeadQuantized(tf.keras.layers.Layer):
box. For example, `num_params_per_anchor` would be 4 for axis-aligned
anchor boxes specified by their y-centers, x-centers, heights, and
widths.
share_classification_heads: A `bool` that indicates whethere
sharing weights among the main and attribute classification heads. Not
used in the QAT model.
**kwargs: Additional keyword arguments to be passed.
"""
del share_classification_heads
super().__init__(**kwargs)
self._config_dict = {
'min_level': min_level,
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Export modules for QAT model serving/inference."""
import tensorflow as tf
from official.projects.qat.vision.modeling import factory as qat_factory
from official.vision.serving import image_classification
from official.vision.serving import semantic_segmentation
class ClassificationModule(image_classification.ClassificationModule):
"""Classification Module."""
def _build_model(self):
model = super()._build_model()
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
return qat_factory.build_qat_classification_model(
model, self.params.task.quantization, input_specs,
self.params.task.model)
class SegmentationModule(semantic_segmentation.SegmentationModule):
"""Segmentation Module."""
def _build_model(self):
model = super()._build_model()
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
return qat_factory.build_qat_segmentation_model(
model, self.params.task.quantization, input_specs)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE = XX
CHECKPOINT_PATH = XX
EXPORT_DIR_PATH = XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
from absl import app
from absl import flags
from official.core import exp_factory
from official.modeling import hyperparams
from official.projects.qat.vision import registry_imports # pylint: disable=unused-import
from official.projects.qat.vision.serving import export_module
from official.vision import configs
from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS
_EXPERIMENT = flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
_EXPORT_DIR = flags.DEFINE_string('export_dir', None, 'The export directory.')
_CHECKPOINT_PATH = flags.DEFINE_string('checkpoint_path', None,
'Checkpoint path.')
_CONFIG_FILE = flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
_PARAMS_OVERRIDE = flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overriden'
' on top of `config_file` template.')
_BATCH_SIZSE = flags.DEFINE_integer('batch_size', None, 'The batch size.')
_IMAGE_TYPE = flags.DEFINE_string(
'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example` and `tflite`.')
_INPUT_IMAGE_SIZE = flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height,width '
'of the input to the model.')
_EXPORT_CHECKPOINT_SUBDIR = flags.DEFINE_string(
'export_checkpoint_subdir', 'checkpoint',
'The subdirectory for checkpoints.')
_EXPORT_SAVED_MODEL_SUBDIR = flags.DEFINE_string(
'export_saved_model_subdir', 'saved_model',
'The subdirectory for saved model.')
_LOG_MODEL_FLOPS_AND_PARAMS = flags.DEFINE_bool(
'log_model_flops_and_params', False,
'If true, logs model flops and parameters.')
_INPUT_NAME = flags.DEFINE_string(
'input_name', None,
'Input tensor name in signature def. Default at None which'
'produces input tensor name `inputs`.')
def main(_):
params = exp_factory.get_exp_config(_EXPERIMENT.value)
for config_file in _CONFIG_FILE.value or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if _PARAMS_OVERRIDE.value:
params = hyperparams.override_params_dict(
params, _PARAMS_OVERRIDE.value, is_strict=True)
params.validate()
params.lock()
input_image_size = [int(x) for x in _INPUT_IMAGE_SIZE.value.split(',')]
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module_cls = export_module.ClassificationModule
elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask):
export_module_cls = export_module.SegmentationModule
else:
raise TypeError(f'Export module for {type(params.task)} is not supported.')
module = export_module_cls(
params=params,
batch_size=_BATCH_SIZSE.value,
input_image_size=input_image_size,
input_type=_IMAGE_TYPE.value,
num_channels=3)
export_saved_model_lib.export_inference_graph(
input_type=_IMAGE_TYPE.value,
batch_size=_BATCH_SIZSE.value,
input_image_size=input_image_size,
params=params,
checkpoint_path=_CHECKPOINT_PATH.value,
export_dir=_EXPORT_DIR.value,
export_checkpoint_subdir=_EXPORT_CHECKPOINT_SUBDIR.value,
export_saved_model_subdir=_EXPORT_SAVED_MODEL_SUBDIR.value,
export_module=module,
log_model_flops_and_params=_LOG_MODEL_FLOPS_AND_PARAMS.value,
input_name=_INPUT_NAME.value)
if __name__ == '__main__':
app.run(main)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binary to convert a saved model to TFLite model for the QAT model."""
from absl import app
from official.projects.qat.vision import registry_imports # pylint: disable=unused-import
from official.vision.serving import export_tflite
if __name__ == '__main__':
app.run(export_tflite.main)
......@@ -95,7 +95,7 @@ class UNet3DDecoder(tf.keras.Model):
channel_dim = 1
# Build 3D UNet.
inputs = self._build_input_pyramid(input_specs, model_id)
inputs = self._build_input_pyramid(input_specs, model_id) # pytype: disable=wrong-arg-types # dynamic-method-lookup
# Add levels with up-convolution or up-sampling.
x = inputs[str(model_id)]
......
......@@ -88,6 +88,7 @@ class DataConfig(cfg.DataConfig):
def yt8m(is_training):
"""YT8M dataset configs."""
# pylint: disable=unexpected-keyword-arg
return DataConfig(
num_frames=30,
temporal_stride=1,
......@@ -95,8 +96,10 @@ def yt8m(is_training):
segment_size=5,
is_training=is_training,
split='train' if is_training else 'valid',
drop_remainder=is_training, # pytype: disable=wrong-keyword-args
num_examples=YT8M_TRAIN_EXAMPLES if is_training else YT8M_VAL_EXAMPLES,
input_path=YT8M_TRAIN_PATH if is_training else YT8M_VAL_PATH)
# pylint: enable=unexpected-keyword-arg
@dataclasses.dataclass
......
......@@ -22,7 +22,6 @@
back into a range between min_quantized_value and max_quantized_value.
link for details: https://research.google.com/youtube8m/download.html
"""
from typing import Dict
import tensorflow as tf
......@@ -424,8 +423,9 @@ class PostBatchProcessor():
[-1, self.num_classes])
else:
video_matrix = tf.squeeze(video_matrix)
labels = tf.squeeze(labels)
# NOTE(b/237445211): Must provide axis argument to tf.squeeze.
video_matrix = tf.squeeze(video_matrix, axis=1)
labels = tf.squeeze(labels, axis=1)
batched_tensors = {
"video_matrix": video_matrix,
......@@ -449,13 +449,15 @@ class TransformBatcher():
self._global_batch_size = input_params.global_batch_size
self._is_training = input_params.is_training
self._include_video_id = input_params.include_video_id
self._drop_remainder = input_params.drop_remainder
def batch_fn(self, dataset, input_context):
"""Add padding when segment_labels is true."""
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
if not self._segment_labels:
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
else:
# add padding
pad_shapes = {
......@@ -476,6 +478,6 @@ class TransformBatcher():
dataset = dataset.padded_batch(
per_replica_batch_size,
padded_shapes=pad_shapes,
drop_remainder=True,
drop_remainder=self._drop_remainder,
padding_values=pad_values)
return dataset
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Provides functions to help with evaluating models."""
import logging
import numpy as np
import tensorflow as tf
from official.projects.yt8m.eval_utils import average_precision_calculator as ap_calculator
......@@ -57,6 +58,9 @@ def calculate_precision_at_equal_recall_rate(predictions, actuals):
"""
aggregated_precision = 0.0
num_videos = actuals.shape[0]
if num_videos == 0:
logging.warning("Num_videos is 0, returning 0.0 aggregated_precision.")
return aggregated_precision
for row in np.arange(num_videos):
num_labels = int(np.sum(actuals[row]))
top_indices = np.argpartition(predictions[row], -num_labels)[-num_labels:]
......
......@@ -180,7 +180,7 @@ class PanopticDeeplabTask(cfg.TaskConfig):
@exp_factory.register_config_factory('panoptic_deeplab_resnet_coco')
def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
def panoptic_deeplab_resnet_coco() -> cfg.ExperimentConfig:
"""COCO panoptic segmentation with Panoptic Deeplab."""
train_steps = 200000
train_batch_size = 64
......@@ -344,3 +344,327 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('panoptic_deeplab_mobilenetv3_large_coco')
def panoptic_deeplab_mobilenetv3_large_coco() -> cfg.ExperimentConfig:
"""COCO panoptic segmentation with Panoptic Deeplab."""
train_steps = 200000
train_batch_size = 64
eval_batch_size = 1
steps_per_epoch = _COCO_TRAIN_EXAMPLES // train_batch_size
validation_steps = _COCO_VAL_EXAMPLES // eval_batch_size
num_panoptic_categories = 201
num_thing_categories = 91
ignore_label = 0
is_thing = [False]
for idx in range(1, num_panoptic_categories):
is_thing.append(True if idx <= num_thing_categories else False)
input_size = [640, 640, 3]
output_stride = 16
aspp_dilation_rates = [6, 12, 18]
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
mixed_precision_dtype='float32', enable_xla=True),
task=PanopticDeeplabTask(
init_checkpoint='gs://tf_model_garden/vision/panoptic/panoptic_deeplab/imagenet/mobilenetv3_large/ckpt-156000',
init_checkpoint_modules=['backbone'],
model=PanopticDeeplab(
num_classes=num_panoptic_categories,
input_size=input_size,
backbone=backbones.Backbone(
type='mobilenet', mobilenet=backbones.MobileNet(
model_id='MobileNetV3Large',
filter_size_scale=1.0,
stochastic_depth_drop_rate=0.0,
output_stride=output_stride)),
decoder=decoders.Decoder(
type='aspp',
aspp=decoders.ASPP(
level=level,
num_filters=256,
pool_kernel_size=input_size[:2],
dilation_rates=aspp_dilation_rates,
use_depthwise_convolution=True,
dropout_rate=0.1)),
semantic_head=SemanticHead(
level=level,
num_convs=1,
num_filters=256,
kernel_size=5,
use_depthwise_convolution=True,
upsample_factor=1,
low_level=[3, 2],
low_level_num_filters=[64, 32],
fusion_num_output_filters=256,
prediction_kernel_size=1),
instance_head=InstanceHead(
level=level,
num_convs=1,
num_filters=32,
kernel_size=5,
use_depthwise_convolution=True,
upsample_factor=1,
low_level=[3, 2],
low_level_num_filters=[32, 16],
fusion_num_output_filters=128,
prediction_kernel_size=1),
shared_decoder=False,
generate_panoptic_masks=True,
post_processor=PanopticDeeplabPostProcessor(
output_size=input_size[:2],
center_score_threshold=0.1,
thing_class_ids=list(range(1, num_thing_categories)),
label_divisor=256,
stuff_area_limit=4096,
ignore_label=ignore_label,
nms_kernel=41,
keep_k_centers=200,
rescale_predictions=True)),
losses=Losses(
label_smoothing=0.0,
ignore_label=ignore_label,
l2_weight_decay=0.0,
top_k_percent_pixels=0.2,
segmentation_loss_weight=1.0,
center_heatmap_loss_weight=200,
center_offset_loss_weight=0.01),
train_data=DataConfig(
input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
parser=Parser(
aug_scale_min=0.5,
aug_scale_max=2.0,
aug_rand_hflip=True,
aug_type=common.Augmentation(
type='autoaug',
autoaug=common.AutoAugment(
augmentation_name='panoptic_deeplab_policy')),
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0)),
validation_data=DataConfig(
input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
parser=Parser(
resize_eval_groundtruth=False,
groundtruth_padded_size=[640, 640],
aug_scale_min=1.0,
aug_scale_max=1.0,
aug_rand_hflip=False,
aug_type=None,
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0),
drop_remainder=False),
evaluation=Evaluation(
ignored_label=ignore_label,
max_instances_per_category=256,
offset=256*256*256,
is_thing=is_thing,
rescale_predictions=True,
report_per_class_pq=False,
report_per_class_iou=False,
report_train_mean_iou=False)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=validation_steps,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.001,
'decay_steps': train_steps,
'end_learning_rate': 0.0,
'power': 0.9
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2000,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('panoptic_deeplab_mobilenetv3_small_coco')
def panoptic_deeplab_mobilenetv3_small_coco() -> cfg.ExperimentConfig:
"""COCO panoptic segmentation with Panoptic Deeplab."""
train_steps = 200000
train_batch_size = 64
eval_batch_size = 1
steps_per_epoch = _COCO_TRAIN_EXAMPLES // train_batch_size
validation_steps = _COCO_VAL_EXAMPLES // eval_batch_size
num_panoptic_categories = 201
num_thing_categories = 91
ignore_label = 0
is_thing = [False]
for idx in range(1, num_panoptic_categories):
is_thing.append(True if idx <= num_thing_categories else False)
input_size = [640, 640, 3]
output_stride = 16
aspp_dilation_rates = [6, 12, 18]
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
mixed_precision_dtype='float32', enable_xla=True),
task=PanopticDeeplabTask(
init_checkpoint='gs://tf_model_garden/vision/panoptic/panoptic_deeplab/imagenet/mobilenetv3_small/ckpt-312000',
init_checkpoint_modules=['backbone'],
model=PanopticDeeplab(
num_classes=num_panoptic_categories,
input_size=input_size,
backbone=backbones.Backbone(
type='mobilenet', mobilenet=backbones.MobileNet(
model_id='MobileNetV3Small',
filter_size_scale=1.0,
stochastic_depth_drop_rate=0.0,
output_stride=output_stride)),
decoder=decoders.Decoder(
type='aspp',
aspp=decoders.ASPP(
level=level,
num_filters=256,
pool_kernel_size=input_size[:2],
dilation_rates=aspp_dilation_rates,
use_depthwise_convolution=True,
dropout_rate=0.1)),
semantic_head=SemanticHead(
level=level,
num_convs=1,
num_filters=256,
kernel_size=5,
use_depthwise_convolution=True,
upsample_factor=1,
low_level=[3, 2],
low_level_num_filters=[64, 32],
fusion_num_output_filters=256,
prediction_kernel_size=1),
instance_head=InstanceHead(
level=level,
num_convs=1,
num_filters=32,
kernel_size=5,
use_depthwise_convolution=True,
upsample_factor=1,
low_level=[3, 2],
low_level_num_filters=[32, 16],
fusion_num_output_filters=128,
prediction_kernel_size=1),
shared_decoder=False,
generate_panoptic_masks=True,
post_processor=PanopticDeeplabPostProcessor(
output_size=input_size[:2],
center_score_threshold=0.1,
thing_class_ids=list(range(1, num_thing_categories)),
label_divisor=256,
stuff_area_limit=4096,
ignore_label=ignore_label,
nms_kernel=41,
keep_k_centers=200,
rescale_predictions=True)),
losses=Losses(
label_smoothing=0.0,
ignore_label=ignore_label,
l2_weight_decay=0.0,
top_k_percent_pixels=0.2,
segmentation_loss_weight=1.0,
center_heatmap_loss_weight=200,
center_offset_loss_weight=0.01),
train_data=DataConfig(
input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
parser=Parser(
aug_scale_min=0.5,
aug_scale_max=2.0,
aug_rand_hflip=True,
aug_type=common.Augmentation(
type='autoaug',
autoaug=common.AutoAugment(
augmentation_name='panoptic_deeplab_policy')),
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0)),
validation_data=DataConfig(
input_path=os.path.join(_COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size,
parser=Parser(
resize_eval_groundtruth=False,
groundtruth_padded_size=[640, 640],
aug_scale_min=1.0,
aug_scale_max=1.0,
aug_rand_hflip=False,
aug_type=None,
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0),
drop_remainder=False),
evaluation=Evaluation(
ignored_label=ignore_label,
max_instances_per_category=256,
offset=256*256*256,
is_thing=is_thing,
rescale_predictions=True,
report_per_class_pq=False,
report_per_class_iou=False,
report_train_mean_iou=False)),
trainer=cfg.TrainerConfig(
train_steps=train_steps,
validation_steps=validation_steps,
validation_interval=steps_per_epoch,
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adam',
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 0.001,
'decay_steps': train_steps,
'end_learning_rate': 0.0,
'power': 0.9
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 2000,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment