Commit 472e2f80 authored by zhanggzh

Merge remote-tracking branch 'tf_model/main'

parents d91296eb f3a14f85
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Bp8t2AI8i7uP"
},
"source": [
"##### Copyright 2022 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "rxPj2Lsni9O4"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6xS-9i5DrRvO"
},
"source": [
"# Customizing a Transformer Encoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Mwb9uw1cDXsa"
},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/customize_encoder\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iLrcV4IyrcGX"
},
"source": [
"## Learning objectives\n",
"\n",
"The [TensorFlow Models NLP library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling) is a collection of tools for building and training modern high performance natural language models.\n",
"\n",
"The `tfm.nlp.networks.EncoderScaffold` is the core of this library, and lots of new network architectures are proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YYxdyoWgsl8t"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fEJSFutUsn_h"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
"which is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mfHI5JyuJ1y9"
},
"outputs": [],
"source": [
"!pip install -q opencv-python"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "thsKZDjhswhR"
},
"outputs": [],
"source": [
"!pip install -q tf-models-official"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hpf7JPCVsqtv"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "my4dp-RMssQe"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"import tensorflow_models as tfm\n",
"nlp = tfm.nlp"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vjDmVsFfs85n"
},
"source": [
"## Canonical BERT encoder\n",
"\n",
"Before learning how to customize the encoder, let's firstly create a canonical BERT enoder and use it to instantiate a `bert_classifier.BertClassifier` for classification task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Oav8sbgstWc-"
},
"outputs": [],
"source": [
"cfg = {\n",
" \"vocab_size\": 100,\n",
" \"hidden_size\": 32,\n",
" \"num_layers\": 3,\n",
" \"num_attention_heads\": 4,\n",
" \"intermediate_size\": 64,\n",
" \"activation\": tfm.utils.activations.gelu,\n",
" \"dropout_rate\": 0.1,\n",
" \"attention_dropout_rate\": 0.1,\n",
" \"max_sequence_length\": 16,\n",
" \"type_vocab_size\": 2,\n",
" \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
"}\n",
"bert_encoder = nlp.networks.BertEncoder(**cfg)\n",
"\n",
"def build_classifier(bert_encoder):\n",
" return nlp.models.BertClassifier(bert_encoder, num_classes=2)\n",
"\n",
"canonical_classifier_model = build_classifier(bert_encoder)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qe2UWI6_tsHo"
},
"source": [
"`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the [Fine tuning bert](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook. We skip the code that trains the model here.\n",
"\n",
"After training, we can apply the model to do prediction.\n"
]
},
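{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is an optional, minimal training sketch on random dummy data. It is not part of the original workflow; it only illustrates that the classifier is a standard Keras model, assuming two classes and sparse categorical cross-entropy on the logits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional, illustrative only: train the classifier on random dummy data to\n",
"# show that it is a standard Keras model. Real training data and settings are\n",
"# covered in the Fine-tuning a BERT model tutorial linked above.\n",
"dummy_word_ids = np.random.randint(\n",
"    cfg[\"vocab_size\"], size=(8, cfg[\"max_sequence_length\"]))\n",
"dummy_mask = np.ones((8, cfg[\"max_sequence_length\"]), dtype=np.int32)\n",
"dummy_type_ids = np.zeros((8, cfg[\"max_sequence_length\"]), dtype=np.int32)\n",
"dummy_labels = np.random.randint(2, size=(8,))\n",
"\n",
"canonical_classifier_model.compile(\n",
"    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),\n",
"    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
"    metrics=[\"accuracy\"])\n",
"canonical_classifier_model.fit(\n",
"    [dummy_word_ids, dummy_mask, dummy_type_ids], dummy_labels, epochs=1)"
]
},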
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "csED2d-Yt5h6"
},
"outputs": [],
"source": [
"def predict(model):\n",
" batch_size = 3\n",
" np.random.seed(0)\n",
" word_ids = np.random.randint(\n",
" cfg[\"vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
" mask = np.random.randint(2, size=(batch_size, cfg[\"max_sequence_length\"]))\n",
" type_ids = np.random.randint(\n",
" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
" print(model([word_ids, mask, type_ids], training=False))\n",
"\n",
"predict(canonical_classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PzKStEK9t_Pb"
},
"source": [
"## Customize BERT encoder\n",
"\n",
"One BERT encoder consists of an embedding network and multiple transformer blocks, and each transformer block contains an attention layer and a feedforward layer."
]
},
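{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see these components, you can plot the canonical `bert_encoder` built above, the same way the customized encoders are plotted later in this notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)"
]
},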
{
"cell_type": "markdown",
"metadata": {
"id": "rmwQfhj6fmKz"
},
"source": [
"We provide easy ways to customize each of those components via (1)\n",
"[EncoderScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) and (2) [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xsMgEVHAui11"
},
"source": [
"### Use EncoderScaffold\n",
"\n",
"`networks.EncoderScaffold` allows users to provide a custom embedding subnetwork\n",
" (which will replace the standard embedding logic) and/or a custom hidden layer class (which will replace the `Transformer` instantiation in the encoder)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-JBabpa2AOz8"
},
"source": [
"#### Without Customization\n",
"\n",
"Without any customization, `networks.EncoderScaffold` behaves the same the canonical `networks.BertEncoder`.\n",
"\n",
"As shown in the following example, `networks.EncoderScaffold` can load `networks.BertEncoder`'s weights and output the same values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ktNzKuVByZQf"
},
"outputs": [],
"source": [
"default_hidden_cfg = dict(\n",
" num_attention_heads=cfg[\"num_attention_heads\"],\n",
" intermediate_size=cfg[\"intermediate_size\"],\n",
" intermediate_activation=cfg[\"activation\"],\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" attention_dropout_rate=cfg[\"attention_dropout_rate\"],\n",
" kernel_initializer=cfg[\"initializer\"],\n",
")\n",
"default_embedding_cfg = dict(\n",
" vocab_size=cfg[\"vocab_size\"],\n",
" type_vocab_size=cfg[\"type_vocab_size\"],\n",
" hidden_size=cfg[\"hidden_size\"],\n",
" initializer=cfg[\"initializer\"],\n",
" dropout_rate=cfg[\"dropout_rate\"],\n",
" max_seq_length=cfg[\"max_sequence_length\"]\n",
")\n",
"default_kwargs = dict(\n",
" hidden_cfg=default_hidden_cfg,\n",
" embedding_cfg=default_embedding_cfg,\n",
" num_hidden_instances=cfg[\"num_layers\"],\n",
" pooled_output_dim=cfg[\"hidden_size\"],\n",
" return_all_layer_outputs=True,\n",
" pooler_layer_initializer=cfg[\"initializer\"],\n",
")\n",
"\n",
"encoder_scaffold = nlp.networks.EncoderScaffold(**default_kwargs)\n",
"classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
"classifier_model_from_encoder_scaffold.set_weights(\n",
" canonical_classifier_model.get_weights())\n",
"predict(classifier_model_from_encoder_scaffold)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sMaUmLyIuwcs"
},
"source": [
"#### Customize Embedding\n",
"\n",
"Next, we show how to use a customized embedding network.\n",
"\n",
"We first build an embedding network that would replace the default network. This one will have 2 inputs (`mask` and `word_ids`) instead of 3, and won't use positional embeddings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LTinnaG6vcsw"
},
"outputs": [],
"source": [
"word_ids = tf.keras.layers.Input(\n",
" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
"mask = tf.keras.layers.Input(\n",
" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
"embedding_layer = nlp.layers.OnDeviceEmbedding(\n",
" vocab_size=cfg['vocab_size'],\n",
" embedding_width=cfg['hidden_size'],\n",
" initializer=cfg[\"initializer\"],\n",
" name=\"word_embeddings\")\n",
"word_embeddings = embedding_layer(word_ids)\n",
"attention_mask = nlp.layers.SelfAttentionMask()([word_embeddings, mask])\n",
"new_embedding_network = tf.keras.Model([word_ids, mask],\n",
" [word_embeddings, attention_mask])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HN7_yu-6O3qI"
},
"source": [
"Inspecting `new_embedding_network`, we can see it takes two inputs:\n",
"`input_word_ids` and `input_mask`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fO9zKFE4OpHp"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9cOaGQHLv12W"
},
"source": [
"We can then build a new encoder using the above `new_embedding_network`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mtFDMNf2vIl9"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use new embedding network.\n",
"kwargs['embedding_cls'] = new_embedding_network\n",
"kwargs['embedding_data'] = embedding_layer.embeddings\n",
"\n",
"encoder_with_customized_embedding = nlp.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_customized_embedding)\n",
"# ... Train the model ...\n",
"print(classifier_model.inputs)\n",
"\n",
"# Assert that there are only two inputs.\n",
"assert len(classifier_model.inputs) == 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z73ZQDtmwg9K"
},
"source": [
"#### Customized Transformer\n",
"\n",
"Users can also override the `hidden_cls` argument in `networks.EncoderScaffold`'s constructor employ a customized Transformer layer.\n",
"\n",
"See [the source of `nlp.layers.ReZeroTransformer`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/rezero_transformer.py) for how to implement a customized Transformer layer.\n",
"\n",
"The following is an example of using `nlp.layers.ReZeroTransformer`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uAIarLZgw6pA"
},
"outputs": [],
"source": [
"kwargs = dict(default_kwargs)\n",
"\n",
"# Use ReZeroTransformer.\n",
"kwargs['hidden_cls'] = nlp.layers.ReZeroTransformer\n",
"\n",
"encoder_with_rezero_transformer = nlp.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_rezero_transformer)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
"assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6PMHFdvnxvR0"
},
"source": [
"### Use `nlp.layers.TransformerScaffold`\n",
"\n",
"The above method of customizing the model requires rewriting the whole `nlp.layers.Transformer` layer, while sometimes you may only want to customize either attention layer or feedforward block. In this case, `nlp.layers.TransformerScaffold` can be used.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "D6FejlgwyAy_"
},
"source": [
"#### Customize Attention Layer\n",
"\n",
"User can also override the `attention_cls` argument in `layers.TransformerScaffold`'s constructor to employ a customized Attention layer.\n",
"\n",
"See [the source of `nlp.layers.TalkingHeadsAttention`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py) for how to implement a customized `Attention` layer.\n",
"\n",
"Following is an example of using `nlp.layers.TalkingHeadsAttention`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nFrSMrZuyNeQ"
},
"outputs": [],
"source": [
"# Use TalkingHeadsAttention\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['attention_cls'] = nlp.layers.TalkingHeadsAttention\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder = nlp.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
"assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tKkZ8spzYmpc"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(encoder_with_rezero_transformer, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kuEJcTyByVvI"
},
"source": [
"#### Customize Feedforward Layer\n",
"\n",
"Similiarly, one could also customize the feedforward layer.\n",
"\n",
"See [the source of `nlp.layers.GatedFeedforward`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py) for how to implement a customized feedforward layer.\n",
"\n",
"Following is an example of using `nlp.layers.GatedFeedforward`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XAbKy_l4y_-i"
},
"outputs": [],
"source": [
"# Use GatedFeedforward\n",
"hidden_cfg = dict(default_hidden_cfg)\n",
"hidden_cfg['feedforward_cls'] = nlp.layers.GatedFeedforward\n",
"\n",
"kwargs = dict(default_kwargs)\n",
"kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n",
"kwargs['hidden_cfg'] = hidden_cfg\n",
"\n",
"encoder_with_gated_feedforward = nlp.networks.EncoderScaffold(**kwargs)\n",
"classifier_model = build_classifier(encoder_with_gated_feedforward)\n",
"# ... Train the model ...\n",
"predict(classifier_model)\n",
"\n",
"# Assert that the variable `gate` from GatedFeedforward exists.\n",
"assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_8NWUhkzeAq"
},
"source": [
"### Build a new Encoder\n",
"\n",
"Finally, you could also build a new encoder using building blocks in the modeling library.\n",
"\n",
"See [the source for `nlp.networks.AlbertEncoder`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_encoder.py) as an example of how to do this. \n",
"\n",
"Here is an example using `nlp.networks.AlbertEncoder`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xsiA3RzUzmUM"
},
"outputs": [],
"source": [
"albert_encoder = nlp.networks.AlbertEncoder(**cfg)\n",
"classifier_model = build_classifier(albert_encoder)\n",
"# ... Train the model ...\n",
"predict(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MeidDfhlHKSO"
},
"source": [
"Inspecting the `albert_encoder`, we see it stacks the same `Transformer` layer multiple times (note the loop-back on the \"Transformer\" block below.."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Uv_juT22HERW"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "customize_encoder.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vXLA5InzXydn"
},
"source": [
"##### Copyright 2021 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "RuRlpLL-X0R_"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2X-XaMSVcLua"
},
"source": [
"# Decoding API"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hYEwGTeCXnnX"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/decoding_api\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/decoding_api.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsACVQpVSifi"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
"which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G4BhAu01HZcM"
},
"outputs": [],
"source": [
"!pip uninstall -y opencv-python"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2j-xhrsVQOQT"
},
"outputs": [],
"source": [
"!pip install tf-models-official"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BjP7zwxmskpY"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow_models import nlp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T92ccAzlnGqh"
},
"outputs": [],
"source": [
"def length_norm(length, dtype):\n",
" \"\"\"Return length normalization factor.\"\"\"\n",
" return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0AWgyo-IQ5sP"
},
"source": [
"## Overview\n",
"\n",
"This API provides an interface to experiment with different decoding strategies used for auto-regressive models.\n",
"\n",
"1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class:\n",
" * [top_p](https://arxiv.org/abs/1904.09751) : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L65) \n",
"\n",
" This implementation chooses the most probable logits with cumulative probabilities up to top_p.\n",
"\n",
" * [top_k](https://arxiv.org/pdf/1805.04833.pdf) : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L48)\n",
"\n",
" At each timestep, this implementation samples from top-k logits based on their probability distribution\n",
"\n",
" * Greedy : [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/sampling_module.py#L26)\n",
"\n",
" This implementation returns the top logits based on probabilities.\n",
"\n",
"2. Beam search is provided in beam_search.py. [github](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search.py)\n",
"\n",
" This implementation reduces the risk of missing hidden high probability logits by keeping the most likely num_beams of logits at each time step and eventually choosing the logits that has the overall highest probability."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfOj7oaBRQnS"
},
"source": [
"## Initialize Sampling Module in TF-NLP.\n",
"\n",
"\n",
"\u003e **symbols_to_logits_fn** : This is a closure implemented by the users of the API. The input to this closure will be \n",
"```\n",
"Args:\n",
" 1] ids [batch_size, .. (index + 1 or 1 if padded_decode is True)],\n",
" 2] index [scalar] : current decoded step,\n",
" 3] cache [nested dictionary of tensors].\n",
"Returns:\n",
" 1] tensor for next-step logits [batch_size, vocab]\n",
" 2] the updated_cache [nested dictionary of tensors].\n",
"```\n",
"This closure calls the model to predict the logits for the 'index+1' step. The cache is used for faster decoding.\n",
"Here is a [reference](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search_test.py#L88) implementation for the above closure.\n",
"\n",
"\n",
"\u003e **length_normalization_fn** : Closure for returning length normalization parameter.\n",
"```\n",
"Args: \n",
" 1] length : scalar for decoded step index.\n",
" 2] dtype : data-type of output tensor\n",
"Returns:\n",
" 1] value of length normalization factor.\n",
"Example :\n",
" def _length_norm(length, dtype):\n",
" return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)\n",
"```\n",
"\n",
"\u003e **vocab_size** : Output vocabulary size.\n",
"\n",
"\u003e **max_decode_length** : Scalar for total number of decoding steps.\n",
"\n",
"\u003e **eos_id** : Decoding will stop if all output decoded ids in the batch have this ID.\n",
"\n",
"\u003e **padded_decode** : Set this to True if running on TPU. Tensors are padded to max_decoding_length if this is True.\n",
"\n",
"\u003e **top_k** : top_k is enabled if this value is \u003e 1.\n",
"\n",
"\u003e **top_p** : top_p is enabled if this value is \u003e 0 and \u003c 1.0\n",
"\n",
"\u003e **sampling_temperature** : This is used to re-estimate the softmax output. Temperature skews the distribution towards high-probability tokens and lowers the mass in the tail distribution. Value has to be positive. Low temperature is equivalent to greedy and makes the distribution sharper, while high temperature makes it flatter.\n",
"\n",
"\u003e **enable_greedy** : By default, this is true and greedy decoding is enabled.\n"
]
},
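{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the `symbols_to_logits_fn` contract described above, the closure below uses a toy vocabulary of size 3 and returns uniform (all-zero) logits. It is illustrative only; the sections below build the actual closure from a small `model_fn`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the symbols_to_logits_fn contract (illustrative only).\n",
"# `ids` has shape [batch_size, index + 1] (or [batch_size, 1] when\n",
"# padded_decode is True); the closure must return logits with shape\n",
"# [batch_size, vocab_size] together with the (possibly updated) cache.\n",
"def uniform_symbols_to_logits_fn(ids, index, cache):\n",
"  del index  # This toy closure ignores the current decoding step.\n",
"  batch_size = tf.shape(ids)[0]\n",
"  vocab_size = 3  # Toy vocabulary size, matching the examples below.\n",
"  logits = tf.zeros([batch_size, vocab_size], dtype=tf.float32)\n",
"  return logits, cache"
]
},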
{
"cell_type": "markdown",
"metadata": {
"id": "lV1RRp6ihnGX"
},
"source": [
"## Initialize the Model Hyper-parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eTsGp2gaKLdE"
},
"outputs": [],
"source": [
"params = {\n",
" 'num_heads': 2,\n",
" 'num_layers': 2,\n",
" 'batch_size': 2,\n",
" 'n_dims': 256,\n",
" 'max_decode_length': 4}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CYXkoplAij01"
},
"source": [
"## Initialize cache. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UGvmd0_dRFYI"
},
"source": [
"In auto-regressive architectures like Transformer based [Encoder-Decoder](https://arxiv.org/abs/1706.03762) models, \n",
"Cache is used for fast sequential decoding.\n",
"It is a nested dictionary storing pre-computed hidden-states (key and values in the self-attention blocks and the cross-attention blocks) for every layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "D6kfZOOKgkm1"
},
"outputs": [],
"source": [
"cache = {\n",
" 'layer_%d' % layer: {\n",
" 'k': tf.zeros(\n",
" shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
" dtype=tf.float32),\n",
" 'v': tf.zeros(\n",
" shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
" dtype=tf.float32)\n",
" } for layer in range(params['num_layers'])\n",
" }\n",
"print(\"cache value shape for layer 1 :\", cache['layer_1']['k'].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "syl7I5nURPgW"
},
"source": [
"### Create model_fn\n",
" In practice, this will be replaced by an actual model implementation such as [here](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer.py#L236)\n",
"```\n",
"Args:\n",
"i : Step that is being decoded.\n",
"Returns:\n",
" logit probabilities of size [batch_size, 1, vocab_size]\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AhzSkRisRdB6"
},
"outputs": [],
"source": [
"probabilities = tf.constant([[[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],\n",
" [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],\n",
" [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],\n",
" [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]]])\n",
"def model_fn(i):\n",
" return probabilities[:, i, :]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FAJ4CpbfVdjr"
},
"outputs": [],
"source": [
"def _symbols_to_logits_fn():\n",
" \"\"\"Calculates logits of the next tokens.\"\"\"\n",
" def symbols_to_logits_fn(ids, i, temp_cache):\n",
" del ids\n",
" logits = tf.cast(tf.math.log(model_fn(i)), tf.float32)\n",
" return logits, temp_cache\n",
" return symbols_to_logits_fn"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R_tV3jyWVL47"
},
"source": [
"## Greedy \n",
"Greedy decoding selects the token id with the highest probability as its next id: $id_t = argmax_{w}P(id | id_{1:t-1})$ at each timestep $t$. The following sketch shows greedy decoding. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aGt9idSkVQEJ"
},
"outputs": [],
"source": [
"greedy_obj = sampling_module.SamplingModule(\n",
" length_normalization_fn=None,\n",
" dtype=tf.float32,\n",
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
" vocab_size=3,\n",
" max_decode_length=params['max_decode_length'],\n",
" eos_id=10,\n",
" padded_decode=False)\n",
"ids, _ = greedy_obj.generate(\n",
" initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
"print(\"Greedy Decoded Ids:\", ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s4pTTsQXVz5O"
},
"source": [
"## top_k sampling\n",
"In *Top-K* sampling, the *K* most likely next token ids are filtered and the probability mass is redistributed among only those *K* ids. "
]
},
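{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a concept check (not part of the decoding API), the small NumPy sketch below shows what keeping the top-*K* ids and redistributing the probability mass means for a single probability vector with *K* = 2:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: top-k filtering on one probability vector (k=2).\n",
"probs = np.array([0.5, 0.3, 0.2])\n",
"k = 2\n",
"top_idx = np.argsort(probs)[-k:]  # ids of the k most probable tokens\n",
"filtered = np.zeros_like(probs)\n",
"filtered[top_idx] = probs[top_idx]\n",
"filtered /= filtered.sum()        # redistribute the mass over the top-k ids\n",
"print(filtered)                   # [0.625 0.375 0.   ]"
]
},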
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pCLWIn6GV5_G"
},
"outputs": [],
"source": [
"top_k_obj = sampling_module.SamplingModule(\n",
" length_normalization_fn=length_norm,\n",
" dtype=tf.float32,\n",
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
" vocab_size=3,\n",
" max_decode_length=params['max_decode_length'],\n",
" eos_id=10,\n",
" sample_temperature=tf.constant(1.0),\n",
" top_k=tf.constant(3),\n",
" padded_decode=False,\n",
" enable_greedy=False)\n",
"ids, _ = top_k_obj.generate(\n",
" initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
"print(\"top-k sampled Ids:\", ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Jp3G-eE_WI4Y"
},
"source": [
"## top_p sampling\n",
"Instead of sampling only from the most likely *K* token ids, in *Top-p* sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability *p*."
]
},
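{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again as a concept check (not part of the decoding API), the sketch below finds the smallest set of ids whose cumulative probability exceeds *p* = 0.7 for a single probability vector:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: nucleus (top-p) filtering on one probability vector.\n",
"probs = np.array([0.5, 0.3, 0.2])\n",
"p = 0.7\n",
"order = np.argsort(probs)[::-1]  # ids sorted by probability, descending\n",
"cutoff = np.searchsorted(np.cumsum(probs[order]), p) + 1\n",
"keep = order[:cutoff]            # smallest set with cumulative probability >= p\n",
"print(keep)                      # [0 1], since 0.5 + 0.3 >= 0.7"
]
},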
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rEGdIWcuWILO"
},
"outputs": [],
"source": [
"top_p_obj = sampling_module.SamplingModule(\n",
" length_normalization_fn=length_norm,\n",
" dtype=tf.float32,\n",
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
" vocab_size=3,\n",
" max_decode_length=params['max_decode_length'],\n",
" eos_id=10,\n",
" sample_temperature=tf.constant(1.0),\n",
" top_p=tf.constant(0.9),\n",
" padded_decode=False,\n",
" enable_greedy=False)\n",
"ids, _ = top_p_obj.generate(\n",
" initial_ids=tf.constant([9, 1]), initial_cache=cache)\n",
"print(\"top-p sampled Ids:\", ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2hcuyJ2VWjDz"
},
"source": [
"## Beam search decoding\n",
"Beam search reduces the risk of missing hidden high probability token ids by keeping the most likely num_beams of hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cJ3WzvSrWmSA"
},
"outputs": [],
"source": [
"beam_size = 2\n",
"params['batch_size'] = 1\n",
"beam_cache = {\n",
" 'layer_%d' % layer: {\n",
" 'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32),\n",
" 'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']], dtype=tf.float32)\n",
" } for layer in range(params['num_layers'])\n",
" }\n",
"print(\"cache key shape for layer 1 :\", beam_cache['layer_1']['k'].shape)\n",
"ids, _ = beam_search.sequence_beam_search(\n",
" symbols_to_logits_fn=_symbols_to_logits_fn(),\n",
" initial_ids=tf.constant([9], tf.int32),\n",
" initial_cache=beam_cache,\n",
" vocab_size=3,\n",
" beam_size=beam_size,\n",
" alpha=0.6,\n",
" max_decode_length=params['max_decode_length'],\n",
" eos_id=10,\n",
" padded_decode=False,\n",
" dtype=tf.float32)\n",
"print(\"Beam search ids:\", ids)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "decoding_api_in_tf_nlp.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vXLA5InzXydn"
},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "RuRlpLL-X0R_"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1mLJmVotXs64"
},
"source": [
"# Fine-tuning a BERT model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hYEwGTeCXnnX"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/fine_tune_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://tfhub.dev/google/collections/bert\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YN2ACivEPxgD"
},
"source": [
"This tutorial demonstrates how to fine-tune a [Bidirectional Encoder Representations from Transformers (BERT)](https://arxiv.org/abs/1810.04805) (Devlin et al., 2018) model using [TensorFlow Model Garden](https://github.com/tensorflow/models).\n",
"\n",
"You can also find the pre-trained BERT model used in this tutorial on [TensorFlow Hub (TF Hub)](https://tensorflow.org/hub). For concrete examples of how to use the models from TF Hub, refer to the [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue) tutorial. If you're just trying to fine-tune a model, the TF Hub tutorial is a good starting point.\n",
"\n",
"On the other hand, if you're interested in deeper customization, follow this tutorial. It shows how to do a lot of things manually, so you can learn how you can customize the workflow from data preprocessing to training, exporting and saving the model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s2d9S2CSSO1z"
},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "69de3375e32a"
},
"source": [
"### Install pip packages"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsACVQpVSifi"
},
"source": [
"Start by installing the TensorFlow Text and Model Garden pip packages.\n",
"\n",
"* `tf-models-official` is the TensorFlow Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` GitHub repo. To include the latest changes, you may install `tf-models-nightly`, which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sE6XUxLOf1s-"
},
"outputs": [],
"source": [
"!pip install -q opencv-python"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yic2y7_o-BCC"
},
"outputs": [],
"source": [
"!pip install -q -U \"tensorflow-text==2.11.*\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NvNr2svBM-p3"
},
"outputs": [],
"source": [
"!pip install -q tf-models-official"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U-7qPCjWUAyy"
},
"source": [
"### Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lXsXev5MNr20"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_models as tfm\n",
"import tensorflow_hub as hub\n",
"import tensorflow_datasets as tfds\n",
"tfds.disable_progress_bar()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mbanlzTvJBsz"
},
"source": [
"### Resources"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PpW0x8TpR8DT"
},
"source": [
"The following directory contains the BERT model's configuration, vocabulary, and a pre-trained checkpoint used in this tutorial:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vzRHOLciR8eq"
},
"outputs": [],
"source": [
"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12\"\n",
"tf.io.gfile.listdir(gs_folder_bert)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qv6abtRvH4xO"
},
"source": [
"## Load and preprocess the dataset\n",
"\n",
"This example uses the GLUE (General Language Understanding Evaluation) MRPC (Microsoft Research Paraphrase Corpus) [dataset from TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc).\n",
"\n",
"This dataset is not set up such that it can be directly fed into the BERT model. The following section handles the necessary preprocessing."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "28DvUhC1YUiB"
},
"source": [
"### Get the dataset from TensorFlow Datasets\n",
"\n",
"The GLUE MRPC (Dolan and Brockett, 2005) dataset is a corpus of sentence pairs automatically extracted from online news sources, with human annotations for whether the sentences in the pair are semantically equivalent. It has the following attributes:\n",
"\n",
"* Number of labels: 2\n",
"* Size of training dataset: 3668\n",
"* Size of evaluation dataset: 408\n",
"* Maximum sequence length of training and evaluation dataset: 128\n",
"\n",
"Begin by loading the MRPC dataset from TFDS:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ijikx5OsH9AT"
},
"outputs": [],
"source": [
"batch_size=32\n",
"glue, info = tfds.load('glue/mrpc',\n",
" with_info=True,\n",
" batch_size=32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QcMTJU4N7VX-"
},
"outputs": [],
"source": [
"glue"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZgBg2r2nYT-K"
},
"source": [
"The `info` object describes the dataset and its features:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IQrHxv7W7jH5"
},
"outputs": [],
"source": [
"info.features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vhsVWYNxazz5"
},
"source": [
"The two classes are:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "n0gfc_VTayfQ"
},
"outputs": [],
"source": [
"info.features['label'].names"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "38zJcap6xkbC"
},
"source": [
"Here is one example from the training set:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xON_i6SkwApW"
},
"outputs": [],
"source": [
"example_batch = next(iter(glue['train']))\n",
"\n",
"for key, value in example_batch.items():\n",
" print(f\"{key:9s}: {value[0].numpy()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R9vEWgKA4SxV"
},
"source": [
"### Preprocess the data\n",
"\n",
"The keys `\"sentence1\"` and `\"sentence2\"` in the GLUE MRPC dataset contain two input sentences for each example.\n",
"\n",
"Because the BERT model from the Model Garden doesn't take raw text as input, two things need to happen first:\n",
"\n",
"1. The text needs to be _tokenized_ (split into word pieces) and converted to _indices_.\n",
"2. Then, the _indices_ need to be packed into the format that the model expects."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9fbTyfJpNr7x"
},
"source": [
"#### The BERT tokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wqeN54S61ZKQ"
},
"source": [
"To fine tune a pre-trained language model from the Model Garden, such as BERT, you need to make sure that you're using exactly the same tokenization, vocabulary, and index mapping as used during training.\n",
"\n",
"The following code rebuilds the tokenizer that was used by the base model using the Model Garden's `tfm.nlp.layers.FastWordpieceBertTokenizer` layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-DK4q5wEBmlB"
},
"outputs": [],
"source": [
"tokenizer = tfm.nlp.layers.FastWordpieceBertTokenizer(\n",
" vocab_file=os.path.join(gs_folder_bert, \"vocab.txt\"),\n",
" lower_case=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zYHDSquU2lDU"
},
"source": [
"Let's tokenize a test sentence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L_OfOYPg853R"
},
"outputs": [],
"source": [
"tokens = tokenizer(tf.constant([\"Hello TensorFlow!\"]))\n",
"tokens"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfjaaMYy5Gt8"
},
"source": [
"Learn more about the tokenization process in the [Subword tokenization](https://www.tensorflow.org/text/guide/subwords_tokenizer) and [Tokenizing with TensorFlow Text](https://www.tensorflow.org/text/guide/tokenizers) guides."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wd1b09OO5GJl"
},
"source": [
"#### Pack the inputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "62UTWLQd9-LB"
},
"source": [
"TensorFlow Model Garden's BERT model doesn't just take the tokenized strings as input. It also expects these to be packed into a particular format. `tfm.nlp.layers.BertPackInputs` layer can handle the conversion from _a list of tokenized sentences_ to the input format expected by the Model Garden's BERT model.\n",
"\n",
"`tfm.nlp.layers.BertPackInputs` packs the two input sentences (per example in the MRCP dataset) concatenated together. This input is expected to start with a `[CLS]` \"This is a classification problem\" token, and each sentence should end with a `[SEP]` \"Separator\" token.\n",
"\n",
"Therefore, the `tfm.nlp.layers.BertPackInputs` layer's constructor takes the `tokenizer`'s special tokens as an argument. It also needs to know the indices of the tokenizer's special tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5iroDlrFDRcF"
},
"outputs": [],
"source": [
"special = tokenizer.get_special_tokens_dict()\n",
"special"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "b71HarkuG92H"
},
"outputs": [],
"source": [
"max_seq_length = 128\n",
"\n",
"packer = tfm.nlp.layers.BertPackInputs(\n",
" seq_length=max_seq_length,\n",
" special_tokens_dict = tokenizer.get_special_tokens_dict())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CZlSZbYd6liN"
},
"source": [
"The `packer` takes a list of tokenized sentences as input. For example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "27dU_VkJHc9S"
},
"outputs": [],
"source": [
"sentences1 = [\"hello tensorflow\"]\n",
"tok1 = tokenizer(sentences1)\n",
"tok1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LURHmNOSHnWN"
},
"outputs": [],
"source": [
"sentences2 = [\"goodbye tensorflow\"]\n",
"tok2 = tokenizer(sentences2)\n",
"tok2"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r8bvB8gI8BqP"
},
"source": [
"Then, it returns a dictionary containing three outputs:\n",
"\n",
"- `input_word_ids`: The tokenized sentences packed together.\n",
"- `input_mask`: The mask indicating which locations are valid in the other outputs.\n",
"- `input_type_ids`: Indicating which sentence each token belongs to."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YsIDTOMJHrUQ"
},
"outputs": [],
"source": [
"packed = packer([tok1, tok2])\n",
"\n",
"for key, tensor in packed.items():\n",
" print(f\"{key:15s}: {tensor[:, :12]}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "red4tRcq74Qc"
},
"source": [
"#### Put it all together\n",
"\n",
"Combine these two parts into a `keras.layers.Layer` that can be attached to your model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Qtz-tv-6nz6"
},
"outputs": [],
"source": [
"class BertInputProcessor(tf.keras.layers.Layer):\n",
" def __init__(self, tokenizer, packer):\n",
" super().__init__()\n",
" self.tokenizer = tokenizer\n",
" self.packer = packer\n",
"\n",
" def call(self, inputs):\n",
" tok1 = self.tokenizer(inputs['sentence1'])\n",
" tok2 = self.tokenizer(inputs['sentence2'])\n",
"\n",
" packed = self.packer([tok1, tok2])\n",
"\n",
" if 'label' in inputs:\n",
" return packed, inputs['label']\n",
" else:\n",
" return packed"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rdy9wp499btU"
},
"source": [
"But for now just apply it to the dataset using `Dataset.map`, since the dataset you loaded from TFDS is a `tf.data.Dataset` object:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qmyh76AL7VAs"
},
"outputs": [],
"source": [
"bert_inputs_processor = BertInputProcessor(tokenizer, packer)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B8SSCtDe9MCk"
},
"outputs": [],
"source": [
"glue_train = glue['train'].map(bert_inputs_processor).prefetch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KXpiDosO9rkY"
},
"source": [
"Here is an example batch from the processed dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ffNvDE6t9rP-"
},
"outputs": [],
"source": [
"example_inputs, example_labels = next(iter(glue_train))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5sxtTuUi-bXt"
},
"outputs": [],
"source": [
"example_inputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wP4z_-9a-dFk"
},
"outputs": [],
"source": [
"example_labels"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jyjTdGpFhO_1"
},
"outputs": [],
"source": [
"for key, value in example_inputs.items():\n",
" print(f'{key:15s} shape: {value.shape}')\n",
"\n",
"print(f'{\"labels\":15s} shape: {example_labels.shape}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mkGHN_FK-50U"
},
"source": [
"The `input_word_ids` contain the token IDs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eGL1_ktWLcgF"
},
"outputs": [],
"source": [
"plt.pcolormesh(example_inputs['input_word_ids'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ulNZ4U96-8JZ"
},
"source": [
"The mask allows the model to cleanly differentiate between the content and the padding. The mask has the same shape as the `input_word_ids`, and contains a `1` anywhere the `input_word_ids` is not padding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zB7mW7DGK3rW"
},
"outputs": [],
"source": [
"plt.pcolormesh(example_inputs['input_mask'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rxLenwAvCkBf"
},
"source": [
"The \"input type\" also has the same shape, but inside the non-padded region, contains a `0` or a `1` indicating which sentence the token is a part of."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2CetH_5C9P2m"
},
"outputs": [],
"source": [
"plt.pcolormesh(example_inputs['input_type_ids'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pxHHeyei_sb9"
},
"source": [
"Apply the same preprocessing to the validation and test subsets of the GLUE MRPC dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yuLKxf6zHxw-"
},
"outputs": [],
"source": [
"glue_validation = glue['validation'].map(bert_inputs_processor).prefetch(1)\n",
"glue_test = glue['test'].map(bert_inputs_processor).prefetch(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FSwymsbkbLDA"
},
"source": [
"## Build, train and export the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bxxO3pJCEM9p"
},
"source": [
"Now that you have formatted the data as expected, you can start working on building and training the model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Efrj3Cn1kLAp"
},
"source": [
"### Build the model\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xxpOY5r2Ayq6"
},
"source": [
"The first step is to download the configuration file—`config_dict`—for the pre-trained BERT model:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v7ap0BONSJuz"
},
"outputs": [],
"source": [
"import json\n",
"\n",
"bert_config_file = os.path.join(gs_folder_bert, \"bert_config.json\")\n",
"config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
"config_dict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pKaEaKJSX85J"
},
"outputs": [],
"source": [
"encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
" 'type':'bert',\n",
" 'bert': config_dict\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LbgzWukNSqOS"
},
"outputs": [],
"source": [
"bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
"bert_encoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "96ldxDSwkVkj"
},
"source": [
"The configuration file defines the core BERT model from the Model Garden, which is a Keras model that predicts the outputs of `num_classes` from the inputs with maximum sequence length `max_seq_length`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cH682__U0FBv"
},
"outputs": [],
"source": [
"bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sFmVG4SKZAw8"
},
"source": [
"Run it on a test batch of data 10 examples from the training set. The output is the logits for the two classes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VTjgPbp4ZDKo"
},
"outputs": [],
"source": [
"bert_classifier(\n",
" example_inputs, training=True).numpy()[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q0NTdwZsQK8n"
},
"source": [
"The `TransformerEncoder` in the center of the classifier above **is** the `bert_encoder`.\n",
"\n",
"If you inspect the encoder, notice the stack of `Transformer` layers connected to those same three inputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8L__-erBwLIQ"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_encoder, show_shapes=True, dpi=48)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mKAvkQc3heSy"
},
"source": [
"### Restore the encoder weights\n",
"\n",
"When built, the encoder is randomly initialized. Restore the encoder's weights from the checkpoint:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "97Ll2Gichd_Y"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
"checkpoint.read(\n",
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2oHOql35k3Dd"
},
"source": [
"Note: The pretrained `TransformerEncoder` is also available on [TensorFlow Hub](https://tensorflow.org/hub). Go to the [TF Hub appendix](#hub_bert) for details."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "115caFLMk-_l"
},
"source": [
"### Set up the optimizer\n",
"\n",
"BERT typically uses the Adam optimizer with weight decay—[AdamW](https://arxiv.org/abs/1711.05101) (`tf.keras.optimizers.experimental.AdamW`).\n",
"It also employs a learning rate schedule that first warms up from 0 and then decays to 0:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c0jBycPDtkxR"
},
"outputs": [],
"source": [
"# Set up epochs and steps\n",
"epochs = 5\n",
"batch_size = 32\n",
"eval_batch_size = 32\n",
"\n",
"train_data_size = info.splits['train'].num_examples\n",
"steps_per_epoch = int(train_data_size / batch_size)\n",
"num_train_steps = steps_per_epoch * epochs\n",
"warmup_steps = int(0.1 * num_train_steps)\n",
"initial_learning_rate=2e-5"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GFankgHK0Rvh"
},
"source": [
"Linear decay from `initial_learning_rate` to zero over `num_train_steps`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qWSyT8P2j4mV"
},
"outputs": [],
"source": [
"linear_decay = tf.keras.optimizers.schedules.PolynomialDecay(\n",
" initial_learning_rate=initial_learning_rate,\n",
" end_learning_rate=0,\n",
" decay_steps=num_train_steps)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "anZPZPAP0Y3n"
},
"source": [
"Warmup to that value over `warmup_steps`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z_AsVCiRkoN1"
},
"outputs": [],
"source": [
"warmup_schedule = tfm.optimization.lr_schedule.LinearWarmup(\n",
" warmup_learning_rate = 0,\n",
" after_warmup_lr_sched = linear_decay,\n",
" warmup_steps = warmup_steps\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "arfbaK6t0kH_"
},
"source": [
"The overall schedule looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rYZGunhqbGUZ"
},
"outputs": [],
"source": [
"x = tf.linspace(0, num_train_steps, 1001)\n",
"y = [warmup_schedule(xi) for xi in x]\n",
"plt.plot(x,y)\n",
"plt.xlabel('Train step')\n",
"plt.ylabel('Learning rate')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bjsmG_fm0opn"
},
"source": [
"Use `tf.keras.optimizers.experimental.AdamW` to instantiate the optimizer with that schedule:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "R8pTNuKIw1dA"
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.experimental.Adam(\n",
" learning_rate = warmup_schedule)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "78FEUOOEkoP0"
},
"source": [
"### Train the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OTNcA0O0nSq9"
},
"source": [
"Set the metric as accuracy and the loss as sparse categorical cross-entropy. Then, compile and train the BERT classifier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d5FeL0b6j7ky"
},
"outputs": [],
"source": [
"metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"\n",
"bert_classifier.compile(\n",
" optimizer=optimizer,\n",
" loss=loss,\n",
" metrics=metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CsrylctIj_Xy"
},
"outputs": [],
"source": [
"bert_classifier.evaluate(glue_validation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hgPPc2oNmcVZ"
},
"outputs": [],
"source": [
"bert_classifier.fit(\n",
" glue_train,\n",
" validation_data=(glue_validation),\n",
" batch_size=32,\n",
" epochs=epochs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IFtKFWbNKb0u"
},
"source": [
"Now run the fine-tuned model on a custom example to see that it works.\n",
"\n",
"Start by encoding some sentence pairs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S1sdW6lLWaEi"
},
"outputs": [],
"source": [
"my_examples = {\n",
" 'sentence1':[\n",
" 'The rain in Spain falls mainly on the plain.',\n",
" 'Look I fine tuned BERT.'],\n",
" 'sentence2':[\n",
" 'It mostly rains on the flat lands of Spain.',\n",
" 'Is it working? This does not match.']\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7ynJibkBRTJF"
},
"source": [
"The model should report class `1` \"match\" for the first example and class `0` \"no-match\" for the second:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "umo0ttrgRYIM"
},
"outputs": [],
"source": [
"ex_packed = bert_inputs_processor(my_examples)\n",
"my_logits = bert_classifier(ex_packed, training=False)\n",
"\n",
"result_cls_ids = tf.argmax(my_logits)\n",
"result_cls_ids"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HNdmOEHKT7e8"
},
"outputs": [],
"source": [
"tf.gather(tf.constant(info.features['label'].names), result_cls_ids)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVo_AnT0l26j"
},
"source": [
"### Export the model\n",
"\n",
"Often the goal of training a model is to _use_ it for something outside of the Python process that created it. You can do this by exporting the model using `tf.saved_model`. (Learn more in the [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide and the [Save and load a model using a distribution strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load) tutorial.)\n",
"\n",
"First, build a wrapper class to export the model. This wrapper does two things:\n",
"\n",
"- First it packages `bert_inputs_processor` and `bert_classifier` together into a single `tf.Module`, so you can export all the functionalities.\n",
"- Second it defines a `tf.function` that implements the end-to-end execution of the model.\n",
"\n",
"Setting the `input_signature` argument of `tf.function` lets you define a fixed signature for the `tf.function`. This can be less surprising than the default automatic retracing behavior."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "78h83mlt9wpY"
},
"outputs": [],
"source": [
"class ExportModel(tf.Module):\n",
" def __init__(self, input_processor, classifier):\n",
" self.input_processor = input_processor\n",
" self.classifier = classifier\n",
"\n",
" @tf.function(input_signature=[{\n",
" 'sentence1': tf.TensorSpec(shape=[None], dtype=tf.string),\n",
" 'sentence2': tf.TensorSpec(shape=[None], dtype=tf.string)}])\n",
" def __call__(self, inputs):\n",
" packed = self.input_processor(inputs)\n",
" logits = self.classifier(packed, training=False)\n",
" result_cls_ids = tf.argmax(logits)\n",
" return {\n",
" 'logits': logits,\n",
" 'class_id': result_cls_ids,\n",
" 'class': tf.gather(\n",
" tf.constant(info.features['label'].names),\n",
" result_cls_ids)\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qnxysGUfIgFQ"
},
"source": [
"Create an instance of this export-model and save it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TmHW9DEFUZ0X"
},
"outputs": [],
"source": [
"export_model = ExportModel(bert_inputs_processor, bert_classifier)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Nl5x6nElZqkP"
},
"outputs": [],
"source": [
"import tempfile\n",
"export_dir=tempfile.mkdtemp(suffix='_saved_model')\n",
"tf.saved_model.save(export_model, export_dir=export_dir,\n",
" signatures={'serving_default': export_model.__call__})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pd8B5dy-ImDJ"
},
"source": [
"Reload the model and compare the results to the original:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9cAhHySVXHD5"
},
"outputs": [],
"source": [
"original_logits = export_model(my_examples)['logits']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "H9cAcYwfW2fy"
},
"outputs": [],
"source": [
"reloaded = tf.saved_model.load(export_dir)\n",
"reloaded_logits = reloaded(my_examples)['logits']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y_ACvKPsVUXC"
},
"outputs": [],
"source": [
"# The results are identical:\n",
"print(original_logits.numpy())\n",
"print()\n",
"print(reloaded_logits.numpy())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lBlPP20dXPFR"
},
"outputs": [],
"source": [
"print(np.mean(abs(original_logits - reloaded_logits)))"
]
},
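{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reloaded model can also be called through its named serving signature. In that case the inputs are passed as keyword tensor arguments matching the dictionary keys declared in the `input_signature`; a minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"serving_fn = reloaded.signatures['serving_default']\n",
"serving_fn(\n",
"    sentence1=tf.constant(my_examples['sentence1']),\n",
"    sentence2=tf.constant(my_examples['sentence2']))['class']"
]
},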
{
"cell_type": "markdown",
"metadata": {
"id": "CPsg7dZwfBM2"
},
"source": [
"Congratulations! You've used `tensorflow_models` to build a BERT-classifier, train it, and export for later use."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eQceYqRFT_Eg"
},
"source": [
"## Optional: BERT on TF Hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QbklKt-w_CiI"
},
"source": [
"\u003ca id=\"hub_bert\"\u003e\u003c/a\u003e\n",
"\n",
"\n",
"You can get the BERT model off the shelf from [TF Hub](https://tfhub.dev/). There are [many versions available along with their input preprocessors](https://tfhub.dev/google/collections/bert/1).\n",
"\n",
"This example uses [a small version of BERT from TF Hub](https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2) that was pre-trained using the English Wikipedia and BooksCorpus datasets, similar to the [original implementation](https://arxiv.org/abs/1908.08962) (Turc et al., 2019).\n",
"\n",
"Start by importing TF Hub:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GDWrHm0BGpbX"
},
"outputs": [],
"source": [
"import tensorflow_hub as hub"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f02f38f83ac4"
},
"source": [
"Select the input preprocessor and the model from TF Hub and wrap them as `hub.KerasLayer` layers:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"# Always make sure you use the right preprocessor.\n",
"hub_preprocessor = hub.KerasLayer(\n",
" \"https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3\")\n",
"\n",
"# This is a really small BERT.\n",
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/2\",\n",
" trainable=True)\n",
"\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iTzF574wivQv"
},
"source": [
"Test run the preprocessor on a batch of data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GOASSKR5R3-N"
},
"outputs": [],
"source": [
"hub_inputs = hub_preprocessor(['Hello TensorFlow!'])\n",
"{key: value[0, :10].numpy() for key, value in hub_inputs.items()} "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XEcYrCR45Uwo"
},
"outputs": [],
"source": [
"result = hub_encoder(\n",
" inputs=hub_inputs,\n",
" training=False,\n",
")\n",
"\n",
"print(\"Pooled output shape:\", result['pooled_output'].shape)\n",
"print(\"Sequence output shape:\", result['sequence_output'].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cjojn8SmLSRI"
},
"source": [
"At this point it would be simple to add a classification head yourself.\n",
"\n",
"The Model Garden `tfm.nlp.models.BertClassifier` class can also build a classifier onto the TF Hub encoder:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9nTDaApyLR70"
},
"outputs": [],
"source": [
"hub_classifier = tfm.nlp.models.BertClassifier(\n",
" bert_encoder,\n",
" num_classes=2,\n",
" dropout_rate=0.1,\n",
" initializer=tf.keras.initializers.TruncatedNormal(\n",
" stddev=0.02))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xMJX3wV0_v7I"
},
"source": [
"The one downside to loading this model from TF Hub is that the structure of internal Keras layers is not restored. This makes it more difficult to inspect or modify the model.\n",
"\n",
"The BERT encoder model—`hub_classifier`—is now a single layer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u_IqwXjRV1vd"
},
"source": [
"For concrete examples of this approach, refer to [Solve Glue tasks using BERT](https://www.tensorflow.org/text/tutorials/bert_glue)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ji3tdLz101km"
},
"source": [
"## Optional: Optimizer `config`s\n",
"\n",
"The `tensorflow_models` package defines serializable `config` classes that describe how to build the live objects. Earlier in this tutorial, you built the optimizer manually.\n",
"\n",
"The configuration below describes an (almost) identical optimizer built by the `optimizer_factory.OptimizerFactory`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fdb9C1ontnH_"
},
"outputs": [],
"source": [
"optimization_config = tfm.optimization.OptimizationConfig(\n",
" optimizer=tfm.optimization.OptimizerConfig(\n",
" type = \"adam\"),\n",
" learning_rate = tfm.optimization.LrConfig(\n",
" type='polynomial',\n",
" polynomial=tfm.optimization.PolynomialLrConfig(\n",
" initial_learning_rate=2e-5,\n",
" end_learning_rate=0.0,\n",
" decay_steps=num_train_steps)),\n",
" warmup = tfm.optimization.WarmupConfig(\n",
" type='linear',\n",
" linear=tfm.optimization.LinearWarmupConfig(warmup_steps=warmup_steps)\n",
" ))\n",
"\n",
"\n",
"fac = tfm.optimization.optimizer_factory.OptimizerFactory(optimization_config)\n",
"lr = fac.build_learning_rate()\n",
"optimizer = fac.build_optimizer(lr=lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Rp7R1hBfv5HG"
},
"outputs": [],
"source": [
"x = tf.linspace(0, num_train_steps, 1001).numpy()\n",
"y = [lr(xi) for xi in x]\n",
"plt.plot(x,y)\n",
"plt.xlabel('Train step')\n",
"plt.ylabel('Learning rate')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ywn5miD_dnuh"
},
"source": [
"The advantage to using `config` objects is that they don't contain any complicated TensorFlow objects, and can be easily serialized to JSON, and rebuilt. Here's the JSON for the above `tfm.optimization.OptimizationConfig`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zo5RV5lud81Y"
},
"outputs": [],
"source": [
"optimization_config = optimization_config.as_dict()\n",
"optimization_config"
]
},
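{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustration, this dictionary can be serialized to a JSON string with the standard `json` module (only the beginning is printed here):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"config_json = json.dumps(optimization_config, indent=2)\n",
"print(config_json[:200])"
]
},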
{
"cell_type": "markdown",
"metadata": {
"id": "Z6qPXPEhekkd"
},
"source": [
"The `tfm.optimization.optimizer_factory.OptimizerFactory` can just as easily build the optimizer from the JSON dictionary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p-bYrvfMYsxp"
},
"outputs": [],
"source": [
"fac = tfm.optimization.optimizer_factory.OptimizerFactory(\n",
" tfm.optimization.OptimizationConfig(optimization_config))\n",
"lr = fac.build_learning_rate()\n",
"optimizer = fac.build_optimizer(lr=lr)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "fine_tune_bert.ipynb",
"private_outputs": true,
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "80xnUmoI7fBX"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8nvTnfs6Q692"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WmfcMK5P5C1G"
},
"source": [
"# Introduction to the TensorFlow Models NLP library"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cH-oJ8R6AHMK"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0H_EFIhq4-MJ"
},
"source": [
"## Learning objectives\n",
"\n",
"In this Colab notebook, you will learn how to build transformer-based models for common NLP tasks including pretraining, span labelling and classification using the building blocks from [NLP modeling library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2N97-dps_nUk"
},
"source": [
"## Install and import"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "459ygAVl_rg0"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
"which is the nightly Model Garden package created daily automatically.\n",
"* `pip` will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Y-qGkdh6_sZc"
},
"outputs": [],
"source": [
"!pip install tf-models-official"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e4huSSwyAG_5"
},
"source": [
"### Import Tensorflow and other libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jqYXqtjBAJd9"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from tensorflow_models import nlp"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "djBQWjvy-60Y"
},
"source": [
"## BERT pretraining model\n",
"\n",
"BERT ([Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805)) introduced the method of pre-training language representations on a large text corpus and then using that model for downstream NLP tasks.\n",
"\n",
"In this section, we will learn how to build a model to pretrain BERT on the masked language modeling task and next sentence prediction task. For simplicity, we only show the minimum example and use dummy data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MKuHVlsCHmiq"
},
"source": [
"### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
"\n",
"The `nlp.networks.BertEncoder` class implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers (`nlp.layers.TransformerEncoderBlock`), but not the masked language model or classification task networks.\n",
"\n",
"The `nlp.models.BertPretrainer` class allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EXkcXz-9BwB3"
},
"outputs": [],
"source": [
"# Build a small transformer network.\n",
"vocab_size = 100\n",
"network = nlp.networks.BertEncoder(\n",
" vocab_size=vocab_size, \n",
" # The number of TransformerEncoderBlock layers\n",
" num_layers=3)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0NH5irV5KTMS"
},
"source": [
"Inspecting the encoder, we see it contains few embedding layers, stacked `nlp.layers.TransformerEncoderBlock` layers and are connected to three input layers:\n",
"\n",
"`input_word_ids`, `input_type_ids` and `input_mask`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lZNoZkBrIoff"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(network, show_shapes=True, expand_nested=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o7eFOZXiIl-b"
},
"outputs": [],
"source": [
"# Create a BERT pretrainer with the created network.\n",
"num_token_predictions = 8\n",
"bert_pretrainer = nlp.models.BertPretrainer(\n",
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d5h5HT7gNHx_"
},
"source": [
"Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `nlp.layers.ClassificationHead` heads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2tcNfm03IBF7"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, expand_nested=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F2oHrXGUIS0M"
},
"outputs": [],
"source": [
"# We can feed some dummy data to get masked language model and sentence output.\n",
"sequence_length = 16\n",
"batch_size = 2\n",
"\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"masked_lm_positions_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"\n",
"outputs = bert_pretrainer(\n",
" [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
"lm_output = outputs[\"masked_lm\"]\n",
"sentence_output = outputs[\"classification\"]\n",
"print(f'lm_output: shape={lm_output.shape}, dtype={lm_output.dtype!r}')\n",
"print(f'sentence_output: shape={sentence_output.shape}, dtype={sentence_output.dtype!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bnx3UCHniCS5"
},
"source": [
"### Compute loss\n",
"Next, we can use `lm_output` and `sentence_output` to compute `loss`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k30H4Q86f52x"
},
"outputs": [],
"source": [
"masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n",
"\n",
"mlm_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=masked_lm_ids_data,\n",
" predictions=lm_output,\n",
" weights=masked_lm_weights_data)\n",
"sentence_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=next_sentence_labels_data,\n",
" predictions=sentence_output)\n",
"loss = mlm_loss + sentence_loss\n",
"\n",
"print(loss)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wrmSs8GjHxVw"
},
"source": [
"With the loss, you can optimize the model.\n",
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_pretraining.py) for the full example.\n"
]
},
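{
"cell_type": "markdown",
"metadata": {},
"source": [
"For illustration, here is a minimal sketch of a single optimization step that reuses the dummy batch and the loss computation above. The `Adam` optimizer and learning rate are placeholder choices, not a tuned pretraining setup:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A placeholder optimizer, only for this sketch.\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)\n",
"\n",
"with tf.GradientTape() as tape:\n",
"  outputs = bert_pretrainer(\n",
"      [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
"  step_mlm_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
"      labels=masked_lm_ids_data,\n",
"      predictions=outputs['masked_lm'],\n",
"      weights=masked_lm_weights_data)\n",
"  step_sentence_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
"      labels=next_sentence_labels_data,\n",
"      predictions=outputs['classification'])\n",
"  step_loss = step_mlm_loss + step_sentence_loss\n",
"\n",
"# Apply one gradient update to all trainable variables.\n",
"grads = tape.gradient(step_loss, bert_pretrainer.trainable_variables)\n",
"optimizer.apply_gradients(zip(grads, bert_pretrainer.trainable_variables))\n",
"print(step_loss)"
]
},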
{
"cell_type": "markdown",
"metadata": {
"id": "k8cQVFvBCV4s"
},
"source": [
"## Span labeling model\n",
"\n",
"Span labeling is the task to assign labels to a span of the text, for example, label a span of text as the answer of a given question.\n",
"\n",
"In this section, we will learn how to build a span labeling model. Again, we use dummy data for simplicity."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xrLLEWpfknUW"
},
"source": [
"### Build a BertSpanLabeler wrapping BertEncoder\n",
"\n",
"The `nlp.models.BertSpanLabeler` class implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
"\n",
"Note that `nlp.models.BertSpanLabeler` wraps a `nlp.networks.BertEncoder`, the weights of which can be restored from the above pretraining model.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B941M4iUCejO"
},
"outputs": [],
"source": [
"network = nlp.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"bert_span_labeler = nlp.models.BertSpanLabeler(network)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QpB9pgj4PpMg"
},
"source": [
"Inspecting the `bert_span_labeler`, we see it wraps the encoder with additional `SpanLabeling` that outputs `start_position` and `end_position`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RbqRNJCLJu4H"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, expand_nested=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fUf1vRxZJwio"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
"\n",
"print(f'start_logits: shape={start_logits.shape}, dtype={start_logits.dtype!r}')\n",
"print(f'end_logits: shape={end_logits.shape}, dtype={end_logits.dtype!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WqhgQaN1lt-G"
},
"source": [
"### Compute loss\n",
"With `start_logits` and `end_logits`, we can compute loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "waqs6azNl3Nn"
},
"outputs": [],
"source": [
"start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"\n",
"start_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" start_positions, start_logits, from_logits=True)\n",
"end_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" end_positions, end_logits, from_logits=True)\n",
"\n",
"total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
"print(total_loss)"
]
},
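{
"cell_type": "markdown",
"metadata": {},
"source": [
"At inference time, a simple way to read out a predicted span is to take the argmax of each set of logits (shown here on the same dummy batch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Predicted start and end token indices for each example in the batch.\n",
"predicted_starts = tf.argmax(start_logits, axis=-1)\n",
"predicted_ends = tf.argmax(end_logits, axis=-1)\n",
"print(predicted_starts.numpy(), predicted_ends.numpy())"
]
},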
{
"cell_type": "markdown",
"metadata": {
"id": "Zdf03YtZmd_d"
},
"source": [
"With the `loss`, you can optimize the model. Please see [run_squad.py](https://github.com/tensorflow/models/blob/master/official/legacy/bert/run_squad.py) for the full example."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0A1XnGSTChg9"
},
"source": [
"## Classification model\n",
"\n",
"In the last section, we show how to build a text classification model.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MSK8OpZgnQa9"
},
"source": [
"### Build a BertClassifier model wrapping BertEncoder\n",
"\n",
"`nlp.models.BertClassifier` implements a [CLS] token classification model containing a single classification head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cXXCsffkCphk"
},
"outputs": [],
"source": [
"network = nlp.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"num_classes = 2\n",
"bert_classifier = nlp.models.BertClassifier(\n",
" network, num_classes=num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8tZKueKYP4bB"
},
"source": [
"Inspecting the `bert_classifier`, we see it wraps the `encoder` with additional `Classification` head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "snlutm9ZJgEZ"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, expand_nested=True, dpi=48)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yyHPHsqBJkCz"
},
"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"\n",
"# Feed the data to the model.\n",
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
"print(f'logits: shape={logits.shape}, dtype={logits.dtype!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w--a2mg4nzKm"
},
"source": [
"### Compute loss\n",
"\n",
"With `logits`, we can compute `loss`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9X0S1DoFn_5Q"
},
"outputs": [],
"source": [
"labels = np.random.randint(num_classes, size=(batch_size))\n",
"\n",
"loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
" labels, logits, from_logits=True)\n",
"print(loss)"
]
},
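{
"cell_type": "markdown",
"metadata": {},
"source": [
"At inference time, the logits can be turned into class probabilities and predicted labels (shown here on the same dummy batch):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Per-example class probabilities and predicted class ids.\n",
"probabilities = tf.nn.softmax(logits, axis=-1)\n",
"predicted_labels = tf.argmax(logits, axis=-1)\n",
"print(probabilities.numpy())\n",
"print(predicted_labels.numpy())"
]
},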
{
"cell_type": "markdown",
"metadata": {
"id": "mzBqOylZo3og"
},
"source": [
"With the `loss`, you can optimize the model. Please see the [Fine tune_bert](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook or the [model training documentation](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md) for the full example."
]
}
],
"metadata": {
"colab": {
"name": "nlp_modeling_library_intro.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "30155835fc9f"
},
"source": [
"##### Copyright 2022 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "906e07f6e562"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5hrbPTziJK15"
},
"source": [
"# Load LM Checkpoints using Model Garden"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-PYqCW1II75I"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/nlp/load_lm_ckpts\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/load_lm_ckpts.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yyyk1KMlJdWd"
},
"source": [
"This tutorial demonstrates how to load BERT, ALBERT and ELECTRA pretrained checkpoints and use them for downstream tasks.\n",
"\n",
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uEG4RYHolQij"
},
"source": [
"## Install TF Model Garden package"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kPfC1NJZnJq1"
},
"outputs": [],
"source": [
"!pip install -U -q \"tf-models-official\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Op9R3zy3lUk8"
},
"source": [
"## Import necessary libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6_y4Rfq23wK-"
},
"outputs": [],
"source": [
"import os\n",
"import yaml\n",
"import json\n",
"\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xjgv3gllzbYQ"
},
"outputs": [],
"source": [
"import tensorflow_models as tfm\n",
"\n",
"from official.core import exp_factory"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J-t2mo6VQNfY"
},
"source": [
"## Load BERT model pretrained checkpoints"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hdBsFnI20LDE"
},
"source": [
"### Select required BERT model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "apn3VgxUlr5G"
},
"outputs": [],
"source": [
"# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
"model_display_name = 'BERT-base cased English' # @param ['BERT-base uncased English','BERT-base cased English','BERT-large uncased English', 'BERT-large cased English', 'BERT-large, Uncased (Whole Word Masking)', 'BERT-large, Cased (Whole Word Masking)', 'BERT-base MultiLingual','BERT-base Chinese']\n",
"\n",
"if model_display_name == 'BERT-base uncased English':\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz\"\n",
" !tar -xvf \"uncased_L-12_H-768_A-12.tar.gz\"\n",
"elif model_display_name == 'BERT-base cased English':\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-12_H-768_A-12.tar.gz\"\n",
" !tar -xvf \"cased_L-12_H-768_A-12.tar.gz\"\n",
"elif model_display_name == \"BERT-large uncased English\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-24_H-1024_A-16.tar.gz\"\n",
" !tar -xvf \"uncased_L-24_H-1024_A-16.tar.gz\"\n",
"elif model_display_name == \"BERT-large cased English\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-24_H-1024_A-16.tar.gz\"\n",
" !tar -xvf \"cased_L-24_H-1024_A-16.tar.gz\"\n",
"elif model_display_name == \"BERT-large, Uncased (Whole Word Masking)\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n",
" !tar -xvf \"wwm_uncased_L-24_H-1024_A-16.tar.gz\"\n",
"elif model_display_name == \"BERT-large, Cased (Whole Word Masking)\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_cased_L-24_H-1024_A-16.tar.gz\"\n",
" !tar -xvf \"wwm_cased_L-24_H-1024_A-16.tar.gz\"\n",
"elif model_display_name == \"BERT-base MultiLingual\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/multi_cased_L-12_H-768_A-12.tar.gz\"\n",
" !tar -xvf \"multi_cased_L-12_H-768_A-12.tar.gz\"\n",
"elif model_display_name == \"BERT-base Chinese\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/chinese_L-12_H-768_A-12.tar.gz\"\n",
" !tar -xvf \"chinese_L-12_H-768_A-12.tar.gz\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jzxyziRuaC95"
},
"outputs": [],
"source": [
"# Lookup table of the directory name corresponding to each model checkpoint\n",
"folder_bert_dict = {\n",
" 'BERT-base uncased English': 'uncased_L-12_H-768_A-12',\n",
" 'BERT-base cased English': 'cased_L-12_H-768_A-12',\n",
" 'BERT-large uncased English': 'uncased_L-24_H-1024_A-16',\n",
" 'BERT-large cased English': 'cased_L-24_H-1024_A-16',\n",
" 'BERT-large, Uncased (Whole Word Masking)': 'wwm_uncased_L-24_H-1024_A-16',\n",
" 'BERT-large, Cased (Whole Word Masking)': 'wwm_cased_L-24_H-1024_A-16',\n",
" 'BERT-base MultiLingual': 'multi_cased_L-12_H-768_A-1',\n",
" 'BERT-base Chinese': 'chinese_L-12_H-768_A-12'\n",
"}\n",
"\n",
"folder_bert = folder_bert_dict.get(model_display_name)\n",
"folder_bert"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q1WrYswpZPlc"
},
"source": [
"### Construct BERT Model Using the New `params.yaml`\n",
"\n",
"params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "quu1s8Hi2szo"
},
"outputs": [],
"source": [
"config_file = os.path.join(folder_bert, \"params.yaml\")\n",
"config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
"config_dict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3t8o0iG9v8ac"
},
"outputs": [],
"source": [
"# Method 1: pass encoder config dict into EncoderConfig\n",
"encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n",
"encoder_config.get().as_dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2I5PetB6wPvb"
},
"outputs": [],
"source": [
"# Method 2: use override_params_dict function to override default Encoder params\n",
"encoder_config = tfm.nlp.encoders.EncoderConfig()\n",
"tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n",
"encoder_config.get().as_dict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5yHiG_9oS3Uw"
},
"source": [
"### Construct BERT Model Using the Old `bert_config.json`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WEyaqLcW3nne"
},
"outputs": [],
"source": [
"bert_config_file = os.path.join(folder_bert, \"bert_config.json\")\n",
"config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())\n",
"config_dict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xSIcaW9tdrl4"
},
"outputs": [],
"source": [
"encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
" 'type':'bert',\n",
" 'bert': config_dict\n",
"})\n",
"\n",
"encoder_config.get().as_dict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yZznAP--TDLe"
},
"source": [
"### Construct a classifier with `encoder_config`\n",
"\n",
"Here, we construct a new BERT Classifier with 2 classes and plot its model architecture. A BERT Classifier consists of a BERT encoder using the selected encoder config, a Dropout layer and a MLP classification head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ny962I8nqs4n"
},
"outputs": [],
"source": [
"bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
"bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)\n",
"\n",
"tf.keras.utils.plot_model(bert_classifier)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IStKfxXkTJMu"
},
"source": [
"### Load Pretrained Weights into the BERT Classifier\n",
"\n",
"The provided pretrained checkpoint only contains weights for the BERT Encoder within the BERT Classifier. Weights for the Classification Head is still randomly initialized."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G9_XCBpOEo4y"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
"checkpoint.read(\n",
" os.path.join(folder_bert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()"
]
},
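{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick smoke test, the restored classifier can be run on a dummy batch, assuming the standard BERT encoder input names. The batch and sequence sizes below are arbitrary choices for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A dummy batch of one all-padding sequence, just to exercise the model.\n",
"dummy_inputs = dict(\n",
"    input_word_ids=tf.zeros((1, 16), dtype=tf.int32),\n",
"    input_mask=tf.ones((1, 16), dtype=tf.int32),\n",
"    input_type_ids=tf.zeros((1, 16), dtype=tf.int32))\n",
"\n",
"# The classification head is still randomly initialized, so these logits are not meaningful yet.\n",
"bert_classifier(dummy_inputs, training=False)"
]
},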
{
"cell_type": "markdown",
"metadata": {
"id": "E6Hu1FFgQWUU"
},
"source": [
"## Load ALBERT model pretrained checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TWUtFeWxQn0V"
},
"outputs": [],
"source": [
"# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
"albert_model_display_name = 'ALBERT-xxlarge English' # @param ['ALBERT-base English', 'ALBERT-large English', 'ALBERT-xlarge English', 'ALBERT-xxlarge English']\n",
"\n",
"if albert_model_display_name == 'ALBERT-base English':\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_base.tar.gz\"\n",
" !tar -xvf \"albert_base.tar.gz\"\n",
"elif albert_model_display_name == 'ALBERT-large English':\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_large.tar.gz\"\n",
" !tar -xvf \"albert_large.tar.gz\"\n",
"elif albert_model_display_name == \"ALBERT-xlarge English\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xlarge.tar.gz\"\n",
" !tar -xvf \"albert_xlarge.tar.gz\"\n",
"elif albert_model_display_name == \"ALBERT-xxlarge English\":\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xxlarge.tar.gz\"\n",
" !tar -xvf \"albert_xxlarge.tar.gz\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5lZDWD7zUAAO"
},
"outputs": [],
"source": [
"# Lookup table of the directory name corresponding to each model checkpoint\n",
"folder_albert_dict = {\n",
" 'ALBERT-base English': 'albert_base',\n",
" 'ALBERT-large English': 'albert_large',\n",
" 'ALBERT-xlarge English': 'albert_xlarge',\n",
" 'ALBERT-xxlarge English': 'albert_xxlarge'\n",
"}\n",
"\n",
"folder_albert = folder_albert_dict.get(albert_model_display_name)\n",
"folder_albert"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ftXwmObdU2fS"
},
"source": [
"### Construct ALBERT Model Using the New `params.yaml`\n",
"\n",
"params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VXn20q2oU1UJ"
},
"outputs": [],
"source": [
"config_file = os.path.join(folder_albert, \"params.yaml\")\n",
"config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
"config_dict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Uo_TSMSvWOX_"
},
"outputs": [],
"source": [
"# Method 1: pass encoder config dict into EncoderConfig\n",
"encoder_config = tfm.nlp.encoders.EncoderConfig(config_dict[\"task\"][\"model\"][\"encoder\"])\n",
"encoder_config.get().as_dict()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u7oJe93uWcy0"
},
"outputs": [],
"source": [
"# Method 2: use override_params_dict function to override default Encoder params\n",
"encoder_config = tfm.nlp.encoders.EncoderConfig()\n",
"tfm.hyperparams.override_params_dict(encoder_config, config_dict[\"task\"][\"model\"][\"encoder\"], is_strict=True)\n",
"encoder_config.get().as_dict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "abpQFw80Wx6c"
},
"source": [
"### Construct ALBERT Model Using the Old `albert_config.json`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Xb99qms6WuPa"
},
"outputs": [],
"source": [
"albert_config_file = os.path.join(folder_albert, \"albert_config.json\")\n",
"config_dict = json.loads(tf.io.gfile.GFile(albert_config_file).read())\n",
"config_dict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mCW0RJHcEtVV"
},
"outputs": [],
"source": [
"encoder_config = tfm.nlp.encoders.EncoderConfig({\n",
" 'type':'albert',\n",
" 'albert': config_dict\n",
"})\n",
"\n",
"encoder_config.get().as_dict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EIAMaOxdZw5u"
},
"source": [
"### Construct a Classifier with `encoder_config`\n",
"\n",
"Here, we construct a new BERT Classifier with 2 classes and plot its model architecture. A BERT Classifier consists of a BERT encoder using the selected encoder config, a Dropout layer and a MLP classification head."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xTkUisEEFEey"
},
"outputs": [],
"source": [
"albert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)\n",
"albert_classifier = tfm.nlp.models.BertClassifier(network=albert_encoder, num_classes=2)\n",
"\n",
"tf.keras.utils.plot_model(albert_classifier)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m6EG_7CaZ2rI"
},
"source": [
"### Load Pretrained Weights into the Classifier\n",
"\n",
"The provided pretrained checkpoint only contains weights for the ALBERT Encoder within the ALBERT Classifier. Weights for the Classification Head is still randomly initialized."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7dOG3agXZ9Dx"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(encoder=albert_encoder)\n",
"checkpoint.read(\n",
" os.path.join(folder_albert, 'bert_model.ckpt')).expect_partial().assert_existing_objects_matched()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6xsbeS-EcCqu"
},
"source": [
"## Load ELECTRA model pretrained checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VpwIrAR4cIBF"
},
"outputs": [],
"source": [
"# @title Download Checkpoint of the Selected Model { display-mode: \"form\", run: \"auto\" }\n",
"electra_model_display_name = 'ELECTRA-small English' # @param ['ELECTRA-small English', 'ELECTRA-base English']\n",
"\n",
"if electra_model_display_name == 'ELECTRA-small English':\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/small.tar.gz\"\n",
" !tar -xvf \"small.tar.gz\"\n",
"elif electra_model_display_name == 'ELECTRA-base English':\n",
" !wget \"https://storage.googleapis.com/tf_model_garden/nlp/electra/base.tar.gz\"\n",
" !tar -xvf \"base.tar.gz\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fy4FmsNOhlNa"
},
"outputs": [],
"source": [
"# Lookup table of the directory name corresponding to each model checkpoint\n",
"folder_electra_dict = {\n",
" 'ELECTRA-small English': 'small',\n",
" 'ELECTRA-base English': 'base'\n",
"}\n",
"\n",
"folder_electra = folder_electra_dict.get(electra_model_display_name)\n",
"folder_electra"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rgAcf-Fl3RTG"
},
"source": [
"### Construct BERT Model Using the `params.yaml`\n",
"\n",
"params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZNBg5xzqh0Gr"
},
"outputs": [],
"source": [
"config_file = os.path.join(folder_electra, \"params.yaml\")\n",
"config_dict = yaml.safe_load(tf.io.gfile.GFile(config_file).read())\n",
"config_dict"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i-yX-KgJyduv"
},
"outputs": [],
"source": [
"disc_encoder_config = tfm.nlp.encoders.EncoderConfig(\n",
" config_dict['model']['discriminator_encoder']\n",
")\n",
"\n",
"disc_encoder_config.get().as_dict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1AdrMkH73VYz"
},
"source": [
"### Construct a Classifier with `encoder_config`\n",
"\n",
"Here, we construct a Classifier with 2 classes and plot its model architecture. A Classifier consists of a ELECTRA discriminator encoder using the selected encoder config, a Dropout layer and a MLP classification head.\n",
"\n",
"**Note**: The generator is discarded and the discriminator is used for downstream tasks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "98Pt-SxszAvN"
},
"outputs": [],
"source": [
"disc_encoder = tfm.nlp.encoders.build_encoder(disc_encoder_config)\n",
"elctra_dic_classifier = tfm.nlp.models.BertClassifier(network=disc_encoder, num_classes=2)\n",
"tf.keras.utils.plot_model(elctra_dic_classifier)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aWQ2FKj64X5U"
},
"source": [
"### Load Pretrained Weights into the Classifier\n",
"\n",
"The provided pretrained checkpoint contains weights for the entire ELECTRA model. We are only loading its discriminator (conveninently named as `encoder`) wights within the Classifier. Weights for the Classification Head is still randomly initialized."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "99pznFJszQfV"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(encoder=disc_encoder)\n",
"checkpoint.read(\n",
" tf.train.latest_checkpoint(os.path.join(folder_electra))\n",
" ).expect_partial().assert_existing_objects_matched()"
]
}
],
"metadata": {
"colab": {
"name": "load_lm_ckpts.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFdPvlXBOdUN"
},
"source": [
"# Training with Orbit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfBg1C5NB3X0"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/orbit\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/orbit/index.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "456h0idS2Xcq"
},
"source": [
"This example will work through fine-tuning a BERT model using the [Orbit](https://www.tensorflow.org/api_docs/python/orbit) training library.\n",
"\n",
"Orbit is a flexible, lightweight library designed to make it easy to write [custom training loops](https://www.tensorflow.org/tutorials/distribute/custom_training) in TensorFlow. Orbit handles common model training tasks such as saving checkpoints, running model evaluations, and setting up summary writing, while giving users full control over implementing the inner training loop. It integrates with `tf.distribute` and supports running on different device types (CPU, GPU, and TPU).\n",
"\n",
"Most examples on [tensorflow.org](https://www.tensorflow.org/) use custom training loops or [model.fit()](https://www.tensorflow.org/api_docs/python/tf/keras/Model) from Keras. Orbit is a good alternative to `model.fit` if your model is complex and your training loop requires more flexibility, control, or customization. Also, using Orbit can simplify the code when there are many different model architectures that all use the same custom training loop.\n",
"\n",
"This tutorial focuses on setting up and using Orbit, rather than details about BERT, model construction, and data processing. For more in-depth tutorials on these topics, refer to the following tutorials:\n",
"\n",
"* [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) - which goes into detail on these sub-topics.\n",
"* [Fine tune BERT for GLUE on TPU](https://www.tensorflow.org/text/tutorials/bert_glue) - which generalizes the code to run any BERT configuration on any [GLUE](https://www.tensorflow.org/datasets/catalog/glue) sub-task, and runs on TPU."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TJ4m3khW3p_W"
},
"source": [
"## Install the TensorFlow Models package\n",
"\n",
"Install and import the necessary packages, then configure all the objects necessary for training a model.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FZlj0U8Aq9Gt"
},
"outputs": [],
"source": [
"!pip install -q opencv-python\n",
"!pip install tensorflow>=2.9.0 tf-models-official"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MEJkRrmapr16"
},
"source": [
"The `tf-models-official` package contains both the `orbit` and `tensorflow_models` modules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dUVPW84Zucuq"
},
"outputs": [],
"source": [
"import tensorflow_models as tfm\n",
"import orbit"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "18Icocf3lwYD"
},
"source": [
"## Setup for training\n",
"\n",
"This tutorial does not focus on configuring the environment, building the model and optimizer, and loading data. All these techniques are covered in more detail in the [Fine tune BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) and [Fine tune BERT with GLUE](https://www.tensorflow.org/text/tutorials/bert_glue) tutorials.\n",
"\n",
"To view how the training is set up for this tutorial, expand the rest of this section.\n",
"\n",
" \u003c!-- \u003cdiv class=\"tfo-display-only-on-site\"\u003e\u003cdevsite-expandable\u003e\n",
" \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eExpand Section\u003c/button\u003e --\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ljy0z-i3okCS"
},
"source": [
"### Import the necessary packages\n",
"\n",
"Import the BERT model and dataset building library from [Tensorflow Model Garden](https://github.com/tensorflow/models)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCBo6wxA2b5n"
},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import pathlib\n",
"import tempfile\n",
"import time\n",
"\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PG1kwhnvq3VC"
},
"outputs": [],
"source": [
"from official.nlp.data import sentence_prediction_dataloader\n",
"from official.nlp import optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PsbhUV_p3wxN"
},
"source": [
"### Configure the distribution strategy\n",
"\n",
"While `tf.distribute` won't help the model's runtime if you're running on a single machine or GPU, it's necessary for TPUs. Setting up a distribution strategy allows you to use the same code regardless of the configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PG702dqstXIk"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n",
" tf.config.experimental_connect_to_cluster(resolver)\n",
" tf.tpu.experimental.initialize_tpu_system(resolver)\n",
" strategy = tf.distribute.TPUStrategy(resolver)\n",
"else:\n",
" strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eaQgM98deAMu"
},
"source": [
"For more information about the TPU setup, refer to the [TPU guide](https://www.tensorflow.org/guide/tpu)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7aOxMLLV32Zm"
},
"source": [
"### Create a model and an optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YRdWzOfK3_56"
},
"outputs": [],
"source": [
"max_seq_length = 128\n",
"learning_rate = 3e-5\n",
"num_train_epochs = 3\n",
"train_batch_size = 32\n",
"eval_batch_size = 64\n",
"\n",
"train_data_size = 3668\n",
"steps_per_epoch = int(train_data_size / train_batch_size)\n",
"\n",
"train_steps = steps_per_epoch * num_train_epochs\n",
"warmup_steps = int(train_steps * 0.1)\n",
"\n",
"print(\"train batch size: \", train_batch_size)\n",
"print(\"train epochs: \", num_train_epochs)\n",
"print(\"steps_per_epoch: \", steps_per_epoch)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BVw3886Ysse6"
},
"outputs": [],
"source": [
"model_dir = pathlib.Path(tempfile.mkdtemp())\n",
"print(model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mu9cV7ew-cVe"
},
"source": [
"\n",
"Create a BERT Classifier model and a simple optimizer. They must be created inside `strategy.scope` so that the variables can be distributed. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gmwtX0cp-mj5"
},
"outputs": [],
"source": [
"with strategy.scope():\n",
" encoder_network = tfm.nlp.encoders.build_encoder(\n",
" tfm.nlp.encoders.EncoderConfig(type=\"bert\"))\n",
" classifier_model = tfm.nlp.models.BertClassifier(\n",
" network=encoder_network, num_classes=2)\n",
"\n",
" optimizer = optimization.create_optimizer(\n",
" init_lr=3e-5,\n",
" num_train_steps=steps_per_epoch * num_train_epochs,\n",
" num_warmup_steps=warmup_steps,\n",
" end_lr=0.0,\n",
" optimizer_type='adamw')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jwJSfewG5jVV"
},
"outputs": [],
"source": [
"tf.keras.utils.plot_model(classifier_model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IQy5pYgAf8Ft"
},
"source": [
"### Initialize from a Checkpoint"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6CE14GEybgRR"
},
"outputs": [],
"source": [
"bert_dir = 'gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12/'\n",
"tf.io.gfile.listdir(bert_dir)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x7fwxz9xidKt"
},
"outputs": [],
"source": [
"bert_checkpoint = bert_dir + 'bert_model.ckpt'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q7EfwVCRe7N_"
},
"outputs": [],
"source": [
"def init_from_ckpt_fn():\n",
" init_checkpoint = tf.train.Checkpoint(**classifier_model.checkpoint_items)\n",
" with strategy.scope():\n",
" (init_checkpoint\n",
" .read(bert_checkpoint)\n",
" .expect_partial()\n",
" .assert_existing_objects_matched())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M0LUMlsde-2f"
},
"outputs": [],
"source": [
"with strategy.scope():\n",
" init_from_ckpt_fn()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gAuns4vN_IYV"
},
"source": [
"\n",
"To use Orbit, create a `tf.train.CheckpointManager` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i7NwM1Jq_MX7"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(model=classifier_model, optimizer=optimizer)\n",
"checkpoint_manager = tf.train.CheckpointManager(\n",
" checkpoint,\n",
" directory=model_dir,\n",
" max_to_keep=5,\n",
" step_counter=optimizer.iterations,\n",
" checkpoint_interval=steps_per_epoch,\n",
" init_fn=init_from_ckpt_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nzeiAFhcCOAo"
},
"source": [
"### Create distributed datasets\n",
"\n",
"As a shortcut for this tutorial, the [GLUE/MPRC dataset](https://www.tensorflow.org/datasets/catalog/glue#gluemrpc) has been converted to a pair of [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) files containing serialized `tf.train.Example` protos.\n",
"\n",
"The data was converted using [this script](https://github.com/tensorflow/models/blob/r2.9.0/official/nlp/data/create_finetuning_data.py).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZVfbiT1dCnDk"
},
"outputs": [],
"source": [
"train_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_train.tf_record\"\n",
"eval_data_path = \"gs://download.tensorflow.org/data/model_garden_colab/mrpc_eval.tf_record\"\n",
"\n",
"def _dataset_fn(input_file_pattern, \n",
" global_batch_size, \n",
" is_training, \n",
" input_context=None):\n",
" data_config = sentence_prediction_dataloader.SentencePredictionDataConfig(\n",
" input_path=input_file_pattern,\n",
" seq_length=max_seq_length,\n",
" global_batch_size=global_batch_size,\n",
" is_training=is_training)\n",
" return sentence_prediction_dataloader.SentencePredictionDataLoader(\n",
" data_config).load(input_context=input_context)\n",
"\n",
"train_dataset = orbit.utils.make_distributed_dataset(\n",
" strategy, _dataset_fn, input_file_pattern=train_data_path,\n",
" global_batch_size=train_batch_size, is_training=True)\n",
"eval_dataset = orbit.utils.make_distributed_dataset(\n",
" strategy, _dataset_fn, input_file_pattern=eval_data_path,\n",
" global_batch_size=eval_batch_size, is_training=False)\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dPgiDBQCjsXW"
},
"source": [
"### Create a loss function\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7MCUmmo2jvXl"
},
"outputs": [],
"source": [
"def loss_fn(labels, logits):\n",
" \"\"\"Classification loss.\"\"\"\n",
" labels = tf.squeeze(labels)\n",
" log_probs = tf.nn.log_softmax(logits, axis=-1)\n",
" one_hot_labels = tf.one_hot(\n",
" tf.cast(labels, dtype=tf.int32), depth=2, dtype=tf.float32)\n",
" per_example_loss = -tf.reduce_sum(\n",
" tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)\n",
" return tf.reduce_mean(per_example_loss)"
]
},
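{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original tutorial), this hand-written loss is equivalent to Keras' sparse categorical cross-entropy computed from logits; the tensors below are illustrative only:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: `loss_fn` should match\n",
"# tf.keras.losses.sparse_categorical_crossentropy with from_logits=True.\n",
"example_logits = tf.constant([[2.0, -1.0], [0.5, 1.5]])\n",
"example_labels = tf.constant([0, 1])\n",
"print(loss_fn(example_labels, example_logits).numpy())\n",
"print(tf.reduce_mean(\n",
"    tf.keras.losses.sparse_categorical_crossentropy(\n",
"        example_labels, example_logits, from_logits=True)).numpy())"
]
},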
{
"cell_type": "markdown",
"metadata": {
"id": "ohlO-8FQkwsr"
},
"source": [
" \u003c/devsite-expandable\u003e\u003c/div\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymhbvPaEJ96T"
},
"source": [
"## Controllers, Trainers and Evaluators\n",
"\n",
"When using Orbit, the `orbit.Controller` class drives the training. The Controller handles the details of distribution strategies, step counting, TensorBoard summaries, and checkpointing.\n",
"\n",
"To implement the training and evaluation, pass a `trainer` and `evaluator`, which are subclass instances of `orbit.AbstractTrainer` and `orbit.AbstractEvaluator`. Keeping with Orbit's light-weight design, these two classes have a minimal interface.\n",
"\n",
"The Controller drives training and evaluation by calling `trainer.train(num_steps)` and `evaluator.evaluate(num_steps)`. These `train` and `evaluate` methods return a dictionary of results for logging.\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a6sU2vBeyXtu"
},
"source": [
"Training is broken into chunks of length `num_steps`. This is set by the Controller's [`steps_per_loop`](https://tensorflow.org/api_docs/python/orbit/Controller#args) argument. With the trainer and evaluator abstract base classes, the meaning of `num_steps` is entirely determined by the implementer.\n",
"\n",
"Some common examples include:\n",
"\n",
"* Having the chunks represent dataset-epoch boundaries, like the default keras setup. \n",
"* Using it to more efficiently dispatch a number of training steps to an accelerator with a single `tf.function` call (like the `steps_per_execution` argument to `Model.compile`). \n",
"* Subdividing into smaller chunks as needed.\n"
]
},
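{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of that contract (illustrative only; the rest of this tutorial uses the richer `orbit.StandardTrainer` introduced next), a bare `orbit.AbstractTrainer` only needs a `train(num_steps)` method that does some work and returns a dictionary of results to log:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class CountingTrainer(orbit.AbstractTrainer):\n",
"  \"\"\"A do-nothing trainer that only counts the steps it is asked to run.\"\"\"\n",
"\n",
"  def __init__(self):\n",
"    self.global_step = tf.Variable(0, dtype=tf.int64, trainable=False)\n",
"\n",
"  def train(self, num_steps):\n",
"    # A real trainer would run `num_steps` optimization steps here.\n",
"    self.global_step.assign_add(tf.cast(num_steps, tf.int64))\n",
"    return {'global_step': self.global_step.numpy()}\n",
"\n",
"CountingTrainer().train(tf.constant(10))"
]
},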
{
"cell_type": "markdown",
"metadata": {
"id": "p4mXGIRJsf1j"
},
"source": [
"### StandardTrainer and StandardEvaluator\n",
"\n",
"Orbit provides two additional classes, `orbit.StandardTrainer` and `orbit.StandardEvaluator`, to give more structure around the training and evaluation loops.\n",
"\n",
"With StandardTrainer, you only need to set `train_loop_begin`, `train_step`, and `train_loop_end`. The base class handles the loops, dataset logic, and `tf.function` (according to the options set by their `orbit.StandardTrainerOptions`). This is simpler than `orbit.AbstractTrainer`, which requires you to handle the entire loop. StandardEvaluator has a similar structure and simplification to StandardTrainer.\n",
"\n",
"This is effectively an implementation of the `steps_per_execution` approach used by Keras."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-hvZ8PvohmR5"
},
"source": [
"Contrast this with Keras, where training is divided both into epochs (a single pass over the dataset) and `steps_per_execution`(set within [`Model.compile`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile). In Keras, metric averages are typically accumulated over an epoch, and reported \u0026 reset between epochs. For efficiency, `steps_per_execution` only controls the number of training steps made per call.\n",
"\n",
"In this simple case, `steps_per_loop` (within `StandardTrainer`) will handle both the metric resets and the number of steps per call. \n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NoDFN1L-1jIu"
},
"source": [
"The minimal setup when using these base classes is to implement the methods as follows:\n",
"\n",
"1. `StandardTrainer.train_loop_begin` - Reset your training metrics.\n",
"2. `StandardTrainer.train_step` - Apply a single gradient update.\n",
"3. `StandardTrainer.train_loop_end` - Report your training metrics.\n",
"\n",
"and\n",
"\n",
"4. `StandardEvaluator.eval_begin` - Reset your evaluation metrics.\n",
"5. `StandardEvaluator.eval_step` - Run a single evaluation setep.\n",
"6. `StandardEvaluator.eval_reduce` - This is not necessary in this simple setup.\n",
"7. `StandardEvaluator.eval_end` - Report your evaluation metrics.\n",
"\n",
"Depending on the settings, the base class may wrap the `train_step` and `eval_step` code in `tf.function` or `tf.while_loop`, which has some limitations compared to standard python."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3KPA0NDZt2JD"
},
"source": [
"### Define the trainer class"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6LDPsvJwfuPR"
},
"source": [
"In this section you'll create a subclass of `orbit.StandardTrainer` for this task. \n",
"\n",
"Note: To better explain the `BertClassifierTrainer` class, this section defines each method as a stand-alone function and assembles them into a class at the end.\n",
"\n",
"The trainer needs access to the training data, model, optimizer, and distribution strategy. Pass these as arguments to the initializer.\n",
"\n",
"Define a single training metric, `training_loss`, using `tf.keras.metrics.Mean`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6DQYZN5ax-MG"
},
"outputs": [],
"source": [
"def trainer_init(self,\n",
" train_dataset,\n",
" model,\n",
" optimizer,\n",
" strategy):\n",
" self.strategy = strategy\n",
" with self.strategy.scope():\n",
" self.model = model\n",
" self.optimizer = optimizer\n",
" self.global_step = self.optimizer.iterations\n",
" \n",
"\n",
" self.train_loss = tf.keras.metrics.Mean(\n",
" 'training_loss', dtype=tf.float32)\n",
" orbit.StandardTrainer.__init__(self, train_dataset)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QOwHD7U5hVue"
},
"source": [
"Before starting a run of the training loop, the `train_loop_begin` method will reset the `train_loss` metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AkpcHqXShWL0"
},
"outputs": [],
"source": [
"def train_loop_begin(self):\n",
" self.train_loss.reset_states()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UjtFOFyxn2BB"
},
"source": [
"The `train_step` is a straight-forward loss-calculation and gradient update that is run by the distribution strategy. This is accomplished by defining the gradient step as a nested function (`step_fn`).\n",
"\n",
"The method receives `tf.distribute.DistributedIterator` to handle the [distributed input](https://www.tensorflow.org/tutorials/distribute/input). The method uses `Strategy.run` to execute `step_fn` and feeds it from the distributed iterator.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QuPwNnT5I-GP"
},
"outputs": [],
"source": [
"def train_step(self, iterator):\n",
"\n",
" def step_fn(inputs):\n",
" labels = inputs.pop(\"label_ids\")\n",
" with tf.GradientTape() as tape:\n",
" model_outputs = self.model(inputs, training=True)\n",
" # Raw loss is used for reporting in metrics/logs.\n",
" raw_loss = loss_fn(labels, model_outputs)\n",
" # Scales down the loss for gradients to be invariant from replicas.\n",
" loss = raw_loss / self.strategy.num_replicas_in_sync\n",
"\n",
" grads = tape.gradient(loss, self.model.trainable_variables)\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_variables))\n",
" # For reporting, the metric takes the mean of losses.\n",
" self.train_loss.update_state(raw_loss)\n",
"\n",
" self.strategy.run(step_fn, args=(next(iterator),))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VmQNwx5QpyDt"
},
"source": [
"The `orbit.StandardTrainer` handles the `@tf.function` and loops.\n",
"\n",
"After running through `num_steps` of training, `StandardTrainer` calls `train_loop_end`. The function returns the metric results:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GqCyVk1zzGod"
},
"outputs": [],
"source": [
"def train_loop_end(self):\n",
" return {\n",
" self.train_loss.name: self.train_loss.result(),\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xvmLONl80KUv"
},
"source": [
"Build a subclass of `orbit.StandardTrainer` with those methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oRoL7VE6xt1G"
},
"outputs": [],
"source": [
"class BertClassifierTrainer(orbit.StandardTrainer):\n",
" __init__ = trainer_init\n",
" train_loop_begin = train_loop_begin\n",
" train_step = train_step\n",
" train_loop_end = train_loop_end"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yjG4QAWj1B00"
},
"source": [
"### Define the evaluator class\n",
"\n",
"Note: Like the previous section, this section defines each method as a stand-alone function and assembles them into a `BertClassifierEvaluator` class at the end.\n",
"\n",
"The evaluator is even simpler for this task. It needs access to the evaluation dataset, the model, and the strategy. After saving references to those objects, the constructor just needs to create the metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cvX7seCY1CWj"
},
"outputs": [],
"source": [
"def evaluator_init(self,\n",
" eval_dataset,\n",
" model,\n",
" strategy):\n",
" self.strategy = strategy\n",
" with self.strategy.scope():\n",
" self.model = model\n",
" \n",
" self.eval_loss = tf.keras.metrics.Mean(\n",
" 'evaluation_loss', dtype=tf.float32)\n",
" self.eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
" name='accuracy', dtype=tf.float32)\n",
" orbit.StandardEvaluator.__init__(self, eval_dataset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0r-z-XK7ybyX"
},
"source": [
"Similar to the trainer, the `eval_begin` and `eval_end` methods just need to reset the metrics before the loop and then report the results after the loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7VVb0Tg6yZjI"
},
"outputs": [],
"source": [
"def eval_begin(self):\n",
" self.eval_accuracy.reset_states()\n",
" self.eval_loss.reset_states()\n",
"\n",
"def eval_end(self):\n",
" return {\n",
" self.eval_accuracy.name: self.eval_accuracy.result(),\n",
" self.eval_loss.name: self.eval_loss.result(),\n",
" }"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iDOZcQvttdmZ"
},
"source": [
"The `eval_step` method works like `train_step`. The inner `step_fn` defines the actual work of calculating the loss \u0026 accuracy and updating the metrics. The outer `eval_step` receives `tf.distribute.DistributedIterator` as input, and uses `Strategy.run` to launch the distributed execution to `step_fn`, feeding it from the distributed iterator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JLJnYuuGJjvd"
},
"outputs": [],
"source": [
"def eval_step(self, iterator):\n",
"\n",
" def step_fn(inputs):\n",
" labels = inputs.pop(\"label_ids\")\n",
" model_outputs = self.model(inputs, training=True)\n",
" loss = loss_fn(labels, model_outputs)\n",
" self.eval_loss.update_state(loss)\n",
" self.eval_accuracy.update_state(labels, model_outputs)\n",
"\n",
" self.strategy.run(step_fn, args=(next(iterator),))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Gt3hh0V30QcP"
},
"source": [
"Build a subclass of `orbit.StandardEvaluator` with those methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3zqyLxfNyCgA"
},
"outputs": [],
"source": [
"class BertClassifierEvaluator(orbit.StandardEvaluator):\n",
" __init__ = evaluator_init\n",
" eval_begin = eval_begin\n",
" eval_end = eval_end\n",
" eval_step = eval_step"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aK9gEja9qPOc"
},
"source": [
"### End-to-end training and evaluation\n",
"\n",
"To run the training and evaluation, simply create the trainer, evaluator, and `orbit.Controller` instances. Then call the `Controller.train_and_evaluate` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PqQetxyXqRA9"
},
"outputs": [],
"source": [
"trainer = BertClassifierTrainer(\n",
" train_dataset, classifier_model, optimizer, strategy)\n",
"\n",
"evaluator = BertClassifierEvaluator(\n",
" eval_dataset, classifier_model, strategy)\n",
"\n",
"controller = orbit.Controller(\n",
" trainer=trainer,\n",
" evaluator=evaluator,\n",
" global_step=trainer.global_step,\n",
" steps_per_loop=20,\n",
" checkpoint_manager=checkpoint_manager)\n",
"\n",
"result = controller.train_and_evaluate(\n",
" train_steps=steps_per_epoch * num_train_epochs,\n",
" eval_steps=-1,\n",
" eval_interval=steps_per_epoch)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"Tce3stUlHN0L"
],
"name": "Orbit Tutorial.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
toc:
- title: "Example: Image classification"
path: /tfmodels/vision/image_classification
- title: "Example: Object Detection"
path: /tfmodels/vision/object_detection
- title: "Example: Semantic Segmentation"
path: /tfmodels/vision/semantic_segmentation
- title: "Example: Instance Segmentation"
path: /tfmodels/vision/instance_segmentation
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2020 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qFdPvlXBOdUN"
},
"source": [
"# Image classification with Model Garden"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MfBg1C5NB3X0"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/image_classification\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/image_classification.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ta_nFXaVAqLD"
},
"source": [
"This tutorial fine-tunes a Residual Network (ResNet) from the TensorFlow [Model Garden](https://github.com/tensorflow/models) package (`tensorflow-models`) to classify images in the [CIFAR](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.\n",
"\n",
"Model Garden contains a collection of state-of-the-art vision models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
"\n",
"This tutorial uses a [ResNet](https://arxiv.org/pdf/1512.03385.pdf) model, a state-of-the-art image classifier. This tutorial uses the ResNet-18 model, a convolutional neural network with 18 layers.\n",
"\n",
"This tutorial demonstrates how to:\n",
"1. Use models from the TensorFlow Models package.\n",
"2. Fine-tune a pre-built ResNet for image classification.\n",
"3. Export the tuned ResNet model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G2FlaQcEPOER"
},
"source": [
"## Setup\n",
"\n",
"Install and import the necessary modules."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XvWfdCrvrV5W"
},
"outputs": [],
"source": [
"!pip install -U -q \"tf-models-official\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CKYMTPjOE400"
},
"source": [
"Import TensorFlow, TensorFlow Datasets, and a few helper libraries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wlon1uoIowmZ"
},
"outputs": [],
"source": [
"import pprint\n",
"import tempfile\n",
"\n",
"from IPython import display\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AVTs0jDd1b24"
},
"source": [
"The `tensorflow_models` package contains the ResNet vision model, and the `official.vision.serving` model contains the function to save and export the tuned model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NHT1iiIiBzlC"
},
"outputs": [],
"source": [
"import tensorflow_models as tfm\n",
"\n",
"# These are not in the tfm public API for v2.9. They will be available in v2.10\n",
"from official.vision.serving import export_saved_model_lib\n",
"import official.core.train_lib"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aKv3wdqkQ8FU"
},
"source": [
"## Configure the ResNet-18 model for the Cifar-10 dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5iN8mHEJjKYE"
},
"source": [
"The CIFAR10 dataset contains 60,000 color images in mutually exclusive 10 classes, with 6,000 images in each class.\n",
"\n",
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
"\n",
"Use the `resnet_imagenet` factory configuration, as defined by `tfm.vision.configs.image_classification.image_classification_imagenet`. The configuration is set up to train ResNet to converge on [ImageNet](https://www.image-net.org/)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1M77f88Dj2Td"
},
"outputs": [],
"source": [
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
"tfds_name = 'cifar10'\n",
"ds,ds_info = tfds.load(\n",
"tfds_name,\n",
"with_info=True)\n",
"ds_info"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U6PVwXA-j3E7"
},
"source": [
"Adjust the model and dataset configurations so that it works with Cifar-10 (`cifar10`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YWI7faVStQaV"
},
"outputs": [],
"source": [
"# Configure model\n",
"exp_config.task.model.num_classes = 10\n",
"exp_config.task.model.input_size = list(ds_info.features[\"image\"].shape)\n",
"exp_config.task.model.backbone.resnet.model_id = 18\n",
"\n",
"# Configure training and testing data\n",
"batch_size = 128\n",
"\n",
"exp_config.task.train_data.input_path = ''\n",
"exp_config.task.train_data.tfds_name = tfds_name\n",
"exp_config.task.train_data.tfds_split = 'train'\n",
"exp_config.task.train_data.global_batch_size = batch_size\n",
"\n",
"exp_config.task.validation_data.input_path = ''\n",
"exp_config.task.validation_data.tfds_name = tfds_name\n",
"exp_config.task.validation_data.tfds_split = 'test'\n",
"exp_config.task.validation_data.global_batch_size = batch_size\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DE3ggKzzTD56"
},
"source": [
"Adjust the trainer configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "inE_-4UGkLud"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'GPU'\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'TPU'\n",
"else:\n",
" print('Running on CPU is slow, so only train for a few steps.')\n",
" device = 'CPU'\n",
"\n",
"if device=='CPU':\n",
" train_steps = 20\n",
" exp_config.trainer.steps_per_loop = 5\n",
"else:\n",
" train_steps=5000\n",
" exp_config.trainer.steps_per_loop = 100\n",
"\n",
"exp_config.trainer.summary_interval = 100\n",
"exp_config.trainer.checkpoint_interval = train_steps\n",
"exp_config.trainer.validation_interval = 1000\n",
"exp_config.trainer.validation_steps = ds_info.splits['test'].num_examples // batch_size\n",
"exp_config.trainer.train_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5mTcDnBiTOYD"
},
"source": [
"Print the modified configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tuVfxSBCTK-y"
},
"outputs": [],
"source": [
"pprint.pprint(exp_config.as_dict())\n",
"\n",
"display.Javascript(\"google.colab.output.setIframeHeight('300px');\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w7_X0UHaRF2m"
},
"source": [
"Set up the distribution strategy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ykL14FIbTaSt"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" tf.tpu.experimental.initialize_tpu_system()\n",
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"else:\n",
" print('Warning: this will be really slow.')\n",
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W4k5YH5pTjaK"
},
"source": [
"Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
"\n",
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6MgYSH0PtUaW"
},
"outputs": [],
"source": [
"with distribution_strategy.scope():\n",
" model_dir = tempfile.mkdtemp()\n",
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)\n",
"\n",
"# tf.keras.utils.plot_model(task.build_model(), show_shapes=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IFXEZYdzBKoX"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" print()\n",
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
" print(f'labels.shape: {str(labels.shape):16} labels.dtype: {labels.dtype!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yrwxnGDaRU0U"
},
"source": [
"## Visualize the training data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "683c255c6c52"
},
"source": [
"The dataloader applies a z-score normalization using \n",
"`preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`, so the images returned by the dataset can't be directly displayed by standard tools. The visualization code needs to rescale the data into the [0,1] range."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PdmOz2EC0Nx2"
},
"outputs": [],
"source": [
"plt.hist(images.numpy().flatten());"
]
},
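{
"cell_type": "markdown",
"metadata": {},
"source": [
"The histogram above shows the normalized value range. As a minimal sketch (an illustrative helper, not part of the Model Garden API), a batch can be linearly mapped back into `[0, 1]` for display; the `show_batch` function defined below applies the same min/max rescaling:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rescale_for_display(images):\n",
"  \"\"\"Linearly rescales a batch of images into [0, 1] for plotting.\"\"\"\n",
"  images = tf.cast(images, tf.float32)\n",
"  lo = tf.reduce_min(images)\n",
"  hi = tf.reduce_max(images)\n",
"  return (images - lo) / (hi - lo)\n",
"\n",
"plt.imshow(rescale_for_display(images)[0].numpy());"
]
},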
{
"cell_type": "markdown",
"metadata": {
"id": "7a8582ebde7b"
},
"source": [
"Use `ds_info` (which is an instance of `tfds.core.DatasetInfo`) to lookup the text descriptions of each class ID."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Wq4Wq_CuDG3Q"
},
"outputs": [],
"source": [
"label_info = ds_info.features['label']\n",
"label_info.int2str(1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8c652a6fdbcf"
},
"source": [
"Visualize a batch of the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZKfTxytf1l0d"
},
"outputs": [],
"source": [
"def show_batch(images, labels, predictions=None):\n",
" plt.figure(figsize=(10, 10))\n",
" min = images.numpy().min()\n",
" max = images.numpy().max()\n",
" delta = max - min\n",
"\n",
" for i in range(12):\n",
" plt.subplot(6, 6, i + 1)\n",
" plt.imshow((images[i]-min) / delta)\n",
" if predictions is None:\n",
" plt.title(label_info.int2str(labels[i]))\n",
" else:\n",
" if labels[i] == predictions[i]:\n",
" color = 'g'\n",
" else:\n",
" color = 'r'\n",
" plt.title(label_info.int2str(predictions[i]), color=color)\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xkA5h_RBtYYU"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" show_batch(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v_A9VnL2RbXP"
},
"source": [
"## Visualize the testing data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AXovuumW_I2z"
},
"source": [
"Visualize a batch of images from the validation dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ma-_Eb-nte9A"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10));\n",
"for images, labels in task.build_inputs(exp_config.task.validation_data).take(1):\n",
" show_batch(images, labels)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ihKJt2FHRi2N"
},
"source": [
"## Train and evaluate"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0AFMNvYxtjXx"
},
"outputs": [],
"source": [
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
" distribution_strategy=distribution_strategy,\n",
" task=task,\n",
" mode='train_and_eval',\n",
" params=exp_config,\n",
" model_dir=model_dir,\n",
" run_post_eval=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gCcHMQYhozmA"
},
"outputs": [],
"source": [
"# tf.keras.utils.plot_model(model, show_shapes=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "L7nVfxlBA8Gb"
},
"source": [
"Print the `accuracy`, `top_5_accuracy`, and `validation_loss` evaluation metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0124f938a1b9"
},
"outputs": [],
"source": [
"for key, value in eval_logs.items():\n",
" if isinstance(value, tf.Tensor):\n",
" value = value.numpy()\n",
" print(f'{key:20}: {value:.3f}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TDys5bZ1zsml"
},
"source": [
"Run a batch of the processed training data through the model, and view the results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GhI7zR-Uz1JT"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" predictions = model.predict(images)\n",
" predictions = tf.argmax(predictions, axis=-1)\n",
"\n",
"show_batch(images, labels, tf.cast(predictions, tf.int32))\n",
"\n",
"if device=='CPU':\n",
" plt.suptitle('The model was only trained for a few steps, it is not expected to do well.')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fkE9locGTBgt"
},
"source": [
"## Export a SavedModel"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9669d08c91af"
},
"source": [
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AQCFa7BvtmDg"
},
"outputs": [],
"source": [
"# Saving and exporting the trained model\n",
"export_saved_model_lib.export_inference_graph(\n",
" input_type='image_tensor',\n",
" batch_size=1,\n",
" input_image_size=[32, 32],\n",
" params=exp_config,\n",
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
" export_dir='./export/')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vVr6DxNqTyLZ"
},
"source": [
"Test the exported model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gP7nOvrftsB0"
},
"outputs": [],
"source": [
"# Importing SavedModel\n",
"imported = tf.saved_model.load('./export/')\n",
"model_fn = imported.signatures['serving_default']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GiOp2WVIUNUZ"
},
"source": [
"Visualize the predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BTRMrZQAN4mk"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 10))\n",
"for data in tfds.load('cifar10', split='test').batch(12).take(1):\n",
" predictions = []\n",
" for image in data['image']:\n",
" index = tf.argmax(model_fn(image[tf.newaxis, ...])['logits'], axis=1)[0]\n",
" predictions.append(index)\n",
" show_batch(data['image'], data['label'], predictions)\n",
"\n",
" if device=='CPU':\n",
" plt.suptitle('The model was only trained for a few steps, it is not expected to do better than random.')"
]
}
],
"metadata": {
"colab": {
"name": "classification_with_model_garden.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "eCes7jVU8r08"
},
"source": [
"##### Copyright 2023 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pc1j3ZVF8mmG"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SUUX9CnCYI9Y"
},
"source": [
"# Instance Segmentation with Model Garden\n",
"\n",
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/instance_segmentation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/instance_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/instance_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UjP7bQUdTeFr"
},
"source": [
"This tutorial fine-tunes a [Mask R-CNN](https://arxiv.org/abs/1703.06870) with [Mobilenet V2](https://arxiv.org/abs/1801.04381) as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models).\n",
"\n",
"\n",
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
"\n",
"This tutorial demonstrates how to:\n",
"\n",
"1. Use models from the TensorFlow Models package.\n",
"2. Train/Fine-tune a pre-built Mask R-CNN with mobilenet as backbone for Object Detection and Instance Segmentation\n",
"3. Export the trained/tuned Mask R-CNN model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RDp6Kk1Baoi4"
},
"source": [
"## Install Necessary Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hcl98qUOxlL8"
},
"outputs": [],
"source": [
"!pip install -U -q \"tf-models-official\"\n",
"!pip install -U -q remotezip tqdm opencv-python einops"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5-gCe_YTapey"
},
"source": [
"## Import required libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qa9552Ukgf3d"
},
"outputs": [],
"source": [
"import os\n",
"import io\n",
"import json\n",
"import tqdm\n",
"import shutil\n",
"import pprint\n",
"import pathlib\n",
"import tempfile\n",
"import requests\n",
"import collections\n",
"import matplotlib\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from PIL import Image\n",
"from six import BytesIO\n",
"from etils import epath\n",
"from IPython import display\n",
"from urllib.request import urlopen"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tSCMIDRDP2fV"
},
"outputs": [],
"source": [
"import orbit\n",
"import tensorflow as tf\n",
"import tensorflow_models as tfm\n",
"import tensorflow_datasets as tfds\n",
"\n",
"from official.core import exp_factory\n",
"from official.core import config_definitions as cfg\n",
"from official.vision.data import tfrecord_lib\n",
"from official.vision.serving import export_saved_model_lib\n",
"from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
"from official.vision.utils.object_detection import visualization_utils\n",
"from official.vision.ops.preprocess_ops import normalize_image, resize_and_crop_image\n",
"from official.vision.data.create_coco_tf_record import coco_annotations_to_lists\n",
"\n",
"pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
"print(tf.__version__) # Check the version of tensorflow used\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GIrXW8sp2bKa"
},
"source": [
"## Download subset of lvis dataset\n",
"\n",
"[LVIS](https://www.tensorflow.org/datasets/catalog/lvis): A dataset for large vocabulary instance segmentation.\n",
"\n",
"Note: LVIS uses the COCO 2017 train, validation, and test image sets. \n",
"If you have already downloaded the COCO images, you only need to download \n",
"the LVIS annotations. LVIS val set contains images from COCO 2017 train in \n",
"addition to the COCO 2017 val split."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "F_A9_cS310jf"
},
"outputs": [],
"source": [
"# @title Download annotation files\n",
"\n",
"!wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_train.json.zip\n",
"!unzip -q lvis_v1_train.json.zip\n",
"!rm lvis_v1_train.json.zip\n",
"\n",
"!wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_val.json.zip\n",
"!unzip -q lvis_v1_val.json.zip\n",
"!rm lvis_v1_val.json.zip\n",
"\n",
"!wget https://s3-us-west-2.amazonaws.com/dl.fbaipublicfiles.com/LVIS/lvis_v1_image_info_test_dev.json.zip\n",
"!unzip -q lvis_v1_image_info_test_dev.json.zip\n",
"!rm lvis_v1_image_info_test_dev.json.zip"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "kB-C5Svj11S0"
},
"outputs": [],
"source": [
"# @title Lvis annotation parsing\n",
"\n",
"# Annotations with invalid bounding boxes. Will not be used.\n",
"_INVALID_ANNOTATIONS = [\n",
" # Train split.\n",
" 662101,\n",
" 81217,\n",
" 462924,\n",
" 227817,\n",
" 29381,\n",
" 601484,\n",
" 412185,\n",
" 504667,\n",
" 572573,\n",
" 91937,\n",
" 239022,\n",
" 181534,\n",
" 101685,\n",
" # Validation split.\n",
" 36668,\n",
" 57541,\n",
" 33126,\n",
" 10932,\n",
"]\n",
"\n",
"def get_category_map(annotation_path, num_classes):\n",
" with epath.Path(annotation_path).open() as f:\n",
" data = json.load(f)\n",
"\n",
" category_map = {id+1: {'id': cat_dict['id'],\n",
" 'name': cat_dict['name']}\n",
" for id, cat_dict in enumerate(data['categories'][:num_classes])}\n",
" return category_map\n",
"\n",
"class LvisAnnotation:\n",
" \"\"\"LVIS annotation helper class.\n",
" The format of the annations is explained on\n",
" https://www.lvisdataset.org/dataset.\n",
" \"\"\"\n",
"\n",
" def __init__(self, annotation_path):\n",
" with epath.Path(annotation_path).open() as f:\n",
" data = json.load(f)\n",
" self._data = data\n",
"\n",
" img_id2annotations = collections.defaultdict(list)\n",
" for a in self._data.get('annotations', []):\n",
" if a['category_id'] in category_ids:\n",
" img_id2annotations[a['image_id']].append(a)\n",
" self._img_id2annotations = {\n",
" k: list(sorted(v, key=lambda a: a['id']))\n",
" for k, v in img_id2annotations.items()\n",
" }\n",
"\n",
" @property\n",
" def categories(self):\n",
" \"\"\"Return the category dicts, as sorted in the file.\"\"\"\n",
" return self._data['categories']\n",
"\n",
" @property\n",
" def images(self):\n",
" \"\"\"Return the image dicts, as sorted in the file.\"\"\"\n",
" sub_images = []\n",
" for image_info in self._data['images']:\n",
" if image_info['id'] in self._img_id2annotations:\n",
" sub_images.append(image_info)\n",
" return sub_images\n",
"\n",
" def get_annotations(self, img_id):\n",
" \"\"\"Return all annotations associated with the image id string.\"\"\"\n",
" # Some images don't have any annotations. Return empty list instead.\n",
" return self._img_id2annotations.get(img_id, [])\n",
"\n",
"def _generate_tf_records(prefix, images_zip, annotation_file, num_shards=5):\n",
" \"\"\"Generate TFRecords.\"\"\"\n",
"\n",
" lvis_annotation = LvisAnnotation(annotation_file)\n",
"\n",
" def _process_example(prefix, image_info, id_to_name_map):\n",
" # Search image dirs.\n",
" filename = pathlib.Path(image_info['coco_url']).name\n",
" image = tf.io.read_file(os.path.join(IMGS_DIR, filename))\n",
" instances = lvis_annotation.get_annotations(img_id=image_info['id'])\n",
" instances = [x for x in instances if x['id'] not in _INVALID_ANNOTATIONS]\n",
" # print([x['category_id'] for x in instances])\n",
" is_crowd = {'iscrowd': 0}\n",
" instances = [dict(x, **is_crowd) for x in instances]\n",
" neg_category_ids = image_info.get('neg_category_ids', [])\n",
" not_exhaustive_category_ids = image_info.get(\n",
" 'not_exhaustive_category_ids', []\n",
" )\n",
" data, _ = coco_annotations_to_lists(instances,\n",
" id_to_name_map,\n",
" image_info['height'],\n",
" image_info['width'],\n",
" include_masks=True)\n",
" # data['category_id'] = [id-1 for id in data['category_id']]\n",
" keys_to_features = {\n",
" 'image/encoded':\n",
" tfrecord_lib.convert_to_feature(image.numpy()),\n",
" 'image/filename':\n",
" tfrecord_lib.convert_to_feature(filename.encode('utf8')),\n",
" 'image/format':\n",
" tfrecord_lib.convert_to_feature('jpg'.encode('utf8')),\n",
" 'image/height':\n",
" tfrecord_lib.convert_to_feature(image_info['height']),\n",
" 'image/width':\n",
" tfrecord_lib.convert_to_feature(image_info['width']),\n",
" 'image/source_id':\n",
" tfrecord_lib.convert_to_feature(str(image_info['id']).encode('utf8')),\n",
" 'image/object/bbox/xmin':\n",
" tfrecord_lib.convert_to_feature(data['xmin']),\n",
" 'image/object/bbox/xmax':\n",
" tfrecord_lib.convert_to_feature(data['xmax']),\n",
" 'image/object/bbox/ymin':\n",
" tfrecord_lib.convert_to_feature(data['ymin']),\n",
" 'image/object/bbox/ymax':\n",
" tfrecord_lib.convert_to_feature(data['ymax']),\n",
" 'image/object/class/text':\n",
" tfrecord_lib.convert_to_feature(data['category_names']),\n",
" 'image/object/class/label':\n",
" tfrecord_lib.convert_to_feature(data['category_id']),\n",
" 'image/object/is_crowd':\n",
" tfrecord_lib.convert_to_feature(data['is_crowd']),\n",
" 'image/object/area':\n",
" tfrecord_lib.convert_to_feature(data['area'], 'float_list'),\n",
" 'image/object/mask':\n",
" tfrecord_lib.convert_to_feature(data['encoded_mask_png'])\n",
" }\n",
" # print(keys_to_features['image/object/class/label'])\n",
" example = tf.train.Example(\n",
" features=tf.train.Features(feature=keys_to_features))\n",
" return example\n",
"\n",
"\n",
"\n",
" # file_names = [f\"{prefix}/{pathlib.Path(image_info['coco_url']).name}\"\n",
" # for image_info in lvis_annotation.images]\n",
" # _extract_images(images_zip, file_names)\n",
" writers = [\n",
" tf.io.TFRecordWriter(\n",
" tf_records_dir + prefix +'-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
" for i in range(num_shards)\n",
" ]\n",
" id_to_name_map = {cat_dict['id']: cat_dict['name']\n",
" for cat_dict in lvis_annotation.categories[:NUM_CLASSES]}\n",
" # print(id_to_name_map)\n",
" for idx, image_info in enumerate(tqdm.tqdm(lvis_annotation.images)):\n",
" img_data = requests.get(image_info['coco_url'], stream=True).content\n",
" img_name = image_info['coco_url'].split('/')[-1]\n",
" with open(os.path.join(IMGS_DIR, img_name), 'wb') as handler:\n",
" handler.write(img_data)\n",
" tf_example = _process_example(prefix, image_info, id_to_name_map)\n",
" writers[idx % num_shards].write(tf_example.SerializeToString())\n",
"\n",
" del lvis_annotation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5u2dwjIT2HZu"
},
"outputs": [],
"source": [
"_URLS = {\n",
" 'train_images': 'http://images.cocodataset.org/zips/train2017.zip',\n",
" 'validation_images': 'http://images.cocodataset.org/zips/val2017.zip',\n",
" 'test_images': 'http://images.cocodataset.org/zips/test2017.zip',\n",
"}\n",
"\n",
"train_prefix = 'train'\n",
"valid_prefix = 'val'\n",
"\n",
"train_annotation_path = './lvis_v1_train.json'\n",
"valid_annotation_path = './lvis_v1_val.json'\n",
"\n",
"IMGS_DIR = './lvis_sub_dataset/'\n",
"tf_records_dir = './lvis_tfrecords/'\n",
"\n",
"\n",
"if not os.path.exists(IMGS_DIR):\n",
" os.mkdir(IMGS_DIR)\n",
"\n",
"if not os.path.exists(tf_records_dir):\n",
" os.mkdir(tf_records_dir)\n",
"\n",
"\n",
"\n",
"NUM_CLASSES = 3\n",
"category_index = get_category_map(valid_annotation_path, NUM_CLASSES)\n",
"category_ids = list(category_index.keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KBgl5fG42LpD"
},
"outputs": [],
"source": [
"# Below helper function are taken from github tensorflow dataset lvis\n",
"# https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/lvis/lvis_dataset_builder.py\n",
"_generate_tf_records(train_prefix,\n",
" _URLS['train_images'],\n",
" train_annotation_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "89O59u_H2NIJ"
},
"outputs": [],
"source": [
"_generate_tf_records(valid_prefix,\n",
" _URLS['validation_images'],\n",
" valid_annotation_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EREyevfIY4rz"
},
"source": [
"## Configure the MaskRCNN Resnet FPN COCO model for custom dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "5yGLLvXlPInP"
},
"outputs": [],
"source": [
"train_data_input_path = './lvis_tfrecords/train*'\n",
"valid_data_input_path = './lvis_tfrecords/val*'\n",
"test_data_input_path = './lvis_tfrecords/test*'\n",
"model_dir = './trained_model/'\n",
"export_dir ='./exported_model/'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ms3wRQKAIORe"
},
"outputs": [],
"source": [
"if not os.path.exists(model_dir):\n",
" os.mkdir(model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EXA5NmvDblYP"
},
"source": [
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
"\n",
"\n",
"Use the `retinanet_mobilenet_coco` experiment configuration, as defined by `tfm.vision.configs.maskrcnn.maskrcnn_mobilenet_coco`.\n",
"\n",
"Please find all the registered experiements [here](https://www.tensorflow.org/api_docs/python/tfm/core/exp_factory/get_exp_config)\n",
"\n",
"The configuration defines an experiment to train a Mask R-CNN model with mobilenet as backbone and FPN as decoder. Default Congiguration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
"\n",
"There are also other alternative experiments available such as\n",
"`maskrcnn_resnetfpn_coco`,\n",
"`maskrcnn_spinenet_coco` and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Zi2F1qGgPWOH"
},
"outputs": [],
"source": [
"exp_config = exp_factory.get_exp_config('maskrcnn_mobilenet_coco')"
]
},
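{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, to switch to one of the alternative experiments mentioned above (not used in this tutorial), you would only change the experiment name passed to `get_exp_config`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative alternative only; this tutorial keeps the MobileNet config above.\n",
"# exp_config = exp_factory.get_exp_config('maskrcnn_resnetfpn_coco')"
]
},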
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zo-EaCdmn5j-"
},
"outputs": [],
"source": [
"model_ckpt_path = './model_ckpt/'\n",
"if not os.path.exists(model_ckpt_path):\n",
" os.mkdir(model_ckpt_path)\n",
"\n",
"!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.data-00000-of-00001 './model_ckpt/'\n",
"!gsutil cp gs://tf_model_garden/vision/mobilenet/v2_1.0_float/ckpt-180648.index './model_ckpt/'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ymnwJYaFgHs2"
},
"source": [
"### Adjust the model and dataset configurations so that it works with custom dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zyn9ieZyUbEJ"
},
"outputs": [],
"source": [
"BATCH_SIZE = 8\n",
"HEIGHT, WIDTH = 256, 256\n",
"IMG_SHAPE = [HEIGHT, WIDTH, 3]\n",
"\n",
"\n",
"# Backbone Config\n",
"exp_config.task.annotation_file = None\n",
"exp_config.task.freeze_backbone = True\n",
"exp_config.task.init_checkpoint = \"./model_ckpt/ckpt-180648\"\n",
"exp_config.task.init_checkpoint_modules = \"backbone\"\n",
"\n",
"# Model Config\n",
"exp_config.task.model.num_classes = NUM_CLASSES + 1\n",
"exp_config.task.model.input_size = IMG_SHAPE\n",
"\n",
"# Training Data Config\n",
"exp_config.task.train_data.input_path = train_data_input_path\n",
"exp_config.task.train_data.dtype = 'float32'\n",
"exp_config.task.train_data.global_batch_size = BATCH_SIZE\n",
"exp_config.task.train_data.shuffle_buffer_size = 64\n",
"exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
"exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
"\n",
"# Validation Data Config\n",
"exp_config.task.validation_data.input_path = valid_data_input_path\n",
"exp_config.task.validation_data.dtype = 'float32'\n",
"exp_config.task.validation_data.global_batch_size = BATCH_SIZE"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0409ReANgKzF"
},
"source": [
"### Adjust the trainer configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ne8t5AHRUd9g"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'GPU'\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'TPU'\n",
"else:\n",
" print('Running on CPU is slow, so only train for a few steps.')\n",
" device = 'CPU'\n",
"\n",
"\n",
"train_steps = 2000\n",
"exp_config.trainer.steps_per_loop = 200 # steps_per_loop = num_of_training_examples // train_batch_size\n",
"\n",
"exp_config.trainer.summary_interval = 200\n",
"exp_config.trainer.checkpoint_interval = 200\n",
"exp_config.trainer.validation_interval = 200\n",
"exp_config.trainer.validation_steps = 200 # validation_steps = num_of_validation_examples // eval_batch_size\n",
"exp_config.trainer.train_steps = train_steps\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 200\n",
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.07\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k3I4X-bWgNm0"
},
"source": [
"### Print the modified configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IsmxXNlyWBAK"
},
"outputs": [],
"source": [
"pp.pprint(exp_config.as_dict())\n",
"display.Javascript(\"google.colab.output.setIframeHeight('500px');\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jxarWEHDgQSk"
},
"source": [
"### Set up the distribution strategy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4JxhiGNwQRv2"
},
"outputs": [],
"source": [
"# Setting up the Strategy\n",
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" tf.tpu.experimental.initialize_tpu_system()\n",
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"else:\n",
" print('Warning: this will be really slow.')\n",
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
"\n",
"print(\"Done\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QqZU9f1ugS_A"
},
"source": [
"## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
"\n",
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N5R-7KzORB1n"
},
"outputs": [],
"source": [
"with distribution_strategy.scope():\n",
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fmpz2R_cglIv"
},
"source": [
"## Visualize a batch of the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "O82f_7A8gfnY"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" print()\n",
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
" print(f'labels.keys: {labels.keys()}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dLcSHWjqgl66"
},
"source": [
"### Create Category Index Dictionary to map the labels to coressponding label names"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ajF85r_6R9d9"
},
"outputs": [],
"source": [
"tf_ex_decoder = TfExampleDecoder(include_mask=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gRdveeYVgr7B"
},
"source": [
"### Helper Function for Visualizing the results from TFRecords\n",
"Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw boudning boxes on the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uWEuOs8QStrz"
},
"outputs": [],
"source": [
"def show_batch(raw_records, num_of_examples):\n",
" plt.figure(figsize=(20, 20))\n",
" use_normalized_coordinates=True\n",
" min_score_thresh = 0.30\n",
" for i, serialized_example in enumerate(raw_records):\n",
" plt.subplot(1, 3, i + 1)\n",
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
" image = decoded_tensors['image'].numpy().astype('uint8')\n",
" scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
" # print(decoded_tensors['groundtruth_instance_masks'].numpy().shape)\n",
" # print(decoded_tensors.keys())\n",
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
" image,\n",
" decoded_tensors['groundtruth_boxes'].numpy(),\n",
" decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
" scores,\n",
" category_index=category_index,\n",
" use_normalized_coordinates=use_normalized_coordinates,\n",
" min_score_thresh=min_score_thresh,\n",
" instance_masks=decoded_tensors['groundtruth_instance_masks'].numpy().astype('uint8'),\n",
" line_thickness=4)\n",
"\n",
" plt.imshow(image)\n",
" plt.axis(\"off\")\n",
" plt.title(f\"Image-{i+1}\")\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FergQ2P5gv_j"
},
"source": [
"### Visualization of Train Data\n",
"\n",
"The bounding box detection has three components\n",
" 1. Class label of the object detected.\n",
" 2. Percentage of match between predicted and ground truth bounding boxes.\n",
" 3. Instance Segmentation Mask\n",
"\n",
"**Note**: The reason of everything is 100% is because we are visualising the groundtruth"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lN0zdBwxU5Z5"
},
"outputs": [],
"source": [
"buffer_size = 100\n",
"num_of_examples = 3\n",
"\n",
"train_tfrecords = tf.io.gfile.glob(exp_config.task.train_data.input_path)\n",
"raw_records = tf.data.TFRecordDataset(train_tfrecords).shuffle(buffer_size=buffer_size).take(num_of_examples)\n",
"show_batch(raw_records, num_of_examples)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nn7IZSs5hQLg"
},
"source": [
"## Train and evaluate\n",
"\n",
"We follow the COCO challenge tradition to evaluate the accuracy of object detection based on mAP(mean Average Precision). Please check [here](https://cocodataset.org/#detection-eval) for detail explanation of how evaluation metrics for detection task is done.\n",
"\n",
"**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted bounding box and ground truth bounding box."
]
},
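{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch (an illustrative helper, not part of the Model Garden API), the IoU of two boxes given as `[ymin, xmin, ymax, xmax]` can be computed like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def box_iou(box_a, box_b):\n",
"  \"\"\"Intersection-over-union of two [ymin, xmin, ymax, xmax] boxes.\"\"\"\n",
"  ya1, xa1, ya2, xa2 = box_a\n",
"  yb1, xb1, yb2, xb2 = box_b\n",
"  inter_h = max(0.0, min(ya2, yb2) - max(ya1, yb1))\n",
"  inter_w = max(0.0, min(xa2, xb2) - max(xa1, xb1))\n",
"  intersection = inter_h * inter_w\n",
"  union = ((ya2 - ya1) * (xa2 - xa1)\n",
"           + (yb2 - yb1) * (xb2 - xb1)\n",
"           - intersection)\n",
"  return intersection / union if union > 0 else 0.0\n",
"\n",
"# Two unit boxes overlapping by a quarter: IoU = 0.25 / 1.75 ~= 0.14\n",
"box_iou([0.0, 0.0, 1.0, 1.0], [0.5, 0.5, 1.5, 1.5])"
]
},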
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UTuIs4kFZGv_"
},
"outputs": [],
"source": [
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
" distribution_strategy=distribution_strategy,\n",
" task=task,\n",
" mode='train_and_eval',\n",
" params=exp_config,\n",
" model_dir=model_dir,\n",
" run_post_eval=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rfpH4QHkh1gI"
},
"source": [
"## Load logs in tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wcdOvg6eNP6R"
},
"outputs": [],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir \"./trained_model\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hAo9lozJh2cV"
},
"source": [
"## Saving and exporting the trained model\n",
"\n",
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iZG1vPbTQqFh"
},
"outputs": [],
"source": [
"export_saved_model_lib.export_inference_graph(\n",
" input_type='image_tensor',\n",
" batch_size=1,\n",
" input_image_size=[HEIGHT, WIDTH],\n",
" params=exp_config,\n",
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
" export_dir=export_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OHIfMeVXh7vJ"
},
"source": [
"## Inference from Trained Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uaXyzMvXROTd"
},
"outputs": [],
"source": [
"def load_image_into_numpy_array(path):\n",
" \"\"\"Load an image from file into a numpy array.\n",
"\n",
" Puts image into numpy array to feed into tensorflow graph.\n",
" Note that by convention we put it into a numpy array with shape\n",
" (height, width, channels), where channels=3 for RGB.\n",
"\n",
" Args:\n",
" path: the file path to the image\n",
"\n",
" Returns:\n",
" uint8 numpy array with shape (img_height, img_width, 3)\n",
" \"\"\"\n",
" image = None\n",
" if(path.startswith('http')):\n",
" response = urlopen(path)\n",
" image_data = response.read()\n",
" image_data = BytesIO(image_data)\n",
" image = Image.open(image_data)\n",
" else:\n",
" image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
" image = Image.open(BytesIO(image_data))\n",
"\n",
" (im_width, im_height) = image.size\n",
" return np.array(image.getdata()).reshape(\n",
" (1, im_height, im_width, 3)).astype(np.uint8)\n",
"\n",
"\n",
"\n",
"def build_inputs_for_object_detection(image, input_image_size):\n",
" \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
" image, _ = resize_and_crop_image(\n",
" image,\n",
" input_image_size,\n",
" padded_size=input_image_size,\n",
" aug_scale_min=1.0,\n",
" aug_scale_max=1.0)\n",
" return image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZDI9zv_4h-7-"
},
"source": [
"## Visualize test data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rdyIri-1RThk"
},
"outputs": [],
"source": [
"num_of_examples = 3\n",
"\n",
"test_tfrecords = tf.io.gfile.glob('./lvis_tfrecords/val*')\n",
"test_ds = tf.data.TFRecordDataset(test_tfrecords).take(num_of_examples)\n",
"show_batch(test_ds, num_of_examples)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KkMZm4DtiAHO"
},
"source": [
"## Importing SavedModel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rDozz4NXRZ7p"
},
"outputs": [],
"source": [
"imported = tf.saved_model.load(export_dir)\n",
"model_fn = imported.signatures['serving_default']"
]
},
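{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, not part of the original tutorial), you can inspect the serving signature to see the input it expects and the output keys it returns. Given the export settings above, the input should be a `tf.uint8` image tensor of shape `[1, HEIGHT, WIDTH, 3]`, and the outputs should include keys such as `detection_boxes`, `detection_classes` and `detection_scores`:\n",
"\n",
"```python\n",
"# Inspect the imported serving signature (assumes `model_fn` from the cell above).\n",
"print(model_fn.structured_input_signature)   # expected input spec\n",
"print(list(model_fn.structured_outputs.keys()))  # names of the output tensors\n",
"```"
]
},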
{
"cell_type": "markdown",
"metadata": {
"id": "DUxk4-AjLAcO"
},
"source": [
"## Visualize predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Gez57T5ShYnM"
},
"outputs": [],
"source": [
"def reframe_image_corners_relative_to_boxes(boxes):\n",
" \"\"\"Reframe the image corners ([0, 0, 1, 1]) to be relative to boxes.\n",
" The local coordinate frame of each box is assumed to be relative to\n",
" its own for corners.\n",
" Args:\n",
" boxes: A float tensor of [num_boxes, 4] of (ymin, xmin, ymax, xmax)\n",
" coordinates in relative coordinate space of each bounding box.\n",
" Returns:\n",
" reframed_boxes: Reframes boxes with same shape as input.\n",
" \"\"\"\n",
" ymin, xmin, ymax, xmax = (boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3])\n",
"\n",
" height = tf.maximum(ymax - ymin, 1e-4)\n",
" width = tf.maximum(xmax - xmin, 1e-4)\n",
"\n",
" ymin_out = (0 - ymin) / height\n",
" xmin_out = (0 - xmin) / width\n",
" ymax_out = (1 - ymin) / height\n",
" xmax_out = (1 - xmin) / width\n",
" return tf.stack([ymin_out, xmin_out, ymax_out, xmax_out], axis=1)\n",
"\n",
"def reframe_box_masks_to_image_masks(box_masks, boxes, image_height,\n",
" image_width, resize_method='bilinear'):\n",
" \"\"\"Transforms the box masks back to full image masks.\n",
" Embeds masks in bounding boxes of larger masks whose shapes correspond to\n",
" image shape.\n",
" Args:\n",
" box_masks: A tensor of size [num_masks, mask_height, mask_width].\n",
" boxes: A tf.float32 tensor of size [num_masks, 4] containing the box\n",
" corners. Row i contains [ymin, xmin, ymax, xmax] of the box\n",
" corresponding to mask i. Note that the box corners are in\n",
" normalized coordinates.\n",
" image_height: Image height. The output mask will have the same height as\n",
" the image height.\n",
" image_width: Image width. The output mask will have the same width as the\n",
" image width.\n",
" resize_method: The resize method, either 'bilinear' or 'nearest'. Note that\n",
" 'bilinear' is only respected if box_masks is a float.\n",
" Returns:\n",
" A tensor of size [num_masks, image_height, image_width] with the same dtype\n",
" as `box_masks`.\n",
" \"\"\"\n",
" resize_method = 'nearest' if box_masks.dtype == tf.uint8 else resize_method\n",
" # TODO(rathodv): Make this a public function.\n",
" def reframe_box_masks_to_image_masks_default():\n",
" \"\"\"The default function when there are more than 0 box masks.\"\"\"\n",
"\n",
" num_boxes = tf.shape(box_masks)[0]\n",
" box_masks_expanded = tf.expand_dims(box_masks, axis=3)\n",
"\n",
" resized_crops = tf.image.crop_and_resize(\n",
" image=box_masks_expanded,\n",
" boxes=reframe_image_corners_relative_to_boxes(boxes),\n",
" box_indices=tf.range(num_boxes),\n",
" crop_size=[image_height, image_width],\n",
" method=resize_method,\n",
" extrapolation_value=0)\n",
" return tf.cast(resized_crops, box_masks.dtype)\n",
"\n",
" image_masks = tf.cond(\n",
" tf.shape(box_masks)[0] \u003e 0,\n",
" reframe_box_masks_to_image_masks_default,\n",
" lambda: tf.zeros([0, image_height, image_width, 1], box_masks.dtype))\n",
" return tf.squeeze(image_masks, axis=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6EIRAlXcSQaA"
},
"outputs": [],
"source": [
"input_image_size = (HEIGHT, WIDTH)\n",
"plt.figure(figsize=(20, 20))\n",
"min_score_thresh = 0.40 # Change minimum score for threshold to see all bounding boxes confidences\n",
"\n",
"for i, serialized_example in enumerate(test_ds):\n",
" plt.subplot(1, 3, i+1)\n",
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
" image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
" image = tf.expand_dims(image, axis=0)\n",
" image = tf.cast(image, dtype = tf.uint8)\n",
" image_np = image[0].numpy()\n",
" result = model_fn(image)\n",
" # Visualize detection and masks\n",
" if 'detection_masks' in result:\n",
" # we need to convert np.arrays to tensors\n",
" detection_masks = tf.convert_to_tensor(result['detection_masks'][0])\n",
" detection_boxes = tf.convert_to_tensor(result['detection_boxes'][0])\n",
" detection_masks_reframed = reframe_box_masks_to_image_masks(\n",
" detection_masks, detection_boxes/255.0,\n",
" image_np.shape[0], image_np.shape[1])\n",
" detection_masks_reframed = tf.cast(\n",
" detection_masks_reframed \u003e min_score_thresh,\n",
" np.uint8)\n",
"\n",
" result['detection_masks_reframed'] = detection_masks_reframed.numpy()\n",
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
" image_np,\n",
" result['detection_boxes'][0].numpy(),\n",
" (result['detection_classes'][0] + 0).numpy().astype(int),\n",
" result['detection_scores'][0].numpy(),\n",
" category_index=category_index,\n",
" use_normalized_coordinates=False,\n",
" max_boxes_to_draw=200,\n",
" min_score_thresh=min_score_thresh,\n",
" instance_masks=result.get('detection_masks_reframed', None),\n",
" line_thickness=4)\n",
"\n",
" plt.imshow(image_np)\n",
" plt.axis(\"off\")\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "instance_segmentation.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Cayt5nCXb3WG"
},
"source": [
"##### Copyright 2022 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DYL3CXHRb9-f"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VYDmsvURYZjz"
},
"source": [
"# Object detection with Model Garden\n",
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/object_detection\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/object_detection.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/object_detection.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "69aQq_PXcUvL"
},
"source": [
"This tutorial fine-tunes a [RetinaNet](https://arxiv.org/abs/1708.02002) with ResNet-50 as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models) to detect three different Blood Cells in [BCCD](https://public.roboflow.com/object-detection/bccd) dataset. The RetinaNet is pretrained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017\n",
"\n",
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
"\n",
"This tutorial demonstrates how to:\n",
"\n",
"1. Use models from the Tensorflow Model Garden(TFM) package.\n",
"2. Fine-tune a pre-trained RetinanNet with ResNet-50 as backbone for object detection.\n",
"3. Export the tuned RetinaNet model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IeSHlZyUZl6f"
},
"source": [
"## Install necessary dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Pip0LHj3ZqgL"
},
"outputs": [],
"source": [
"!pip install -U -q \"tensorflow\" \"tf-models-official\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H3kS7Y0sZsIj"
},
"source": [
"## Import required libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hFdVelJ2YbQz"
},
"outputs": [],
"source": [
"import os\n",
"import io\n",
"import pprint\n",
"import tempfile\n",
"import matplotlib\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from PIL import Image\n",
"from six import BytesIO\n",
"from IPython import display\n",
"from urllib.request import urlopen"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TF77J-iMZn_u"
},
"source": [
"## Import required libraries from tensorflow models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iT27_SOTY1Dz"
},
"outputs": [],
"source": [
"import orbit\n",
"import tensorflow_models as tfm\n",
"\n",
"from official.core import exp_factory\n",
"from official.core import config_definitions as cfg\n",
"from official.vision.serving import export_saved_model_lib\n",
"from official.vision.ops.preprocess_ops import normalize_image\n",
"from official.vision.ops.preprocess_ops import resize_and_crop_image\n",
"from official.vision.utils.object_detection import visualization_utils\n",
"from official.vision.dataloaders.tf_example_decoder import TfExampleDecoder\n",
"\n",
"pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
"print(tf.__version__) # Check the version of tensorflow used\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WGbMG8cpZyKa"
},
"source": [
"## Custom dataset preparation for object detection\n",
"\n",
"Models in official repository(of model-garden) requires data in a TFRecords format.\n",
"\n",
"\n",
"Please check [this resource](https://www.tensorflow.org/tutorials/load_data/tfrecord) to learn more about TFRecords data format.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YcFW-xHRZ1xJ"
},
"source": [
"### clone the model-garden repo as the required data conversion codes are within this model-garden repository"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tZLidUjiY1xt"
},
"outputs": [],
"source": [
"!git clone --quiet https://github.com/tensorflow/models.git"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Uq5hcbJ8Z4th"
},
"source": [
"### Upload your custom data in drive or local disk of the notebook and unzip the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rDixpoqoY3Za"
},
"outputs": [],
"source": [
"!curl -L 'https://public.roboflow.com/ds/ZpYLqHeT0W?key=ZXfZLRnhsc' \u003e './BCCD.v1-bccd.coco.zip'\n",
"!unzip -q -o './BCCD.v1-bccd.coco.zip' -d './BCC.v1-bccd.coco/'\n",
"!rm './BCCD.v1-bccd.coco.zip'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jKJ3MtgeZ5om"
},
"source": [
"### Change directory to vision or data where data conversion tools are available"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IYM7PIFbY5EL"
},
"outputs": [],
"source": [
"%cd ./models/"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GI1h9UChZ8cC"
},
"source": [
"### CLI command to convert data(train data)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x_8cmB82Y65O"
},
"outputs": [],
"source": [
"%%bash\n",
"\n",
"TRAIN_DATA_DIR='../BCC.v1-bccd.coco/train'\n",
"TRAIN_ANNOTATION_FILE_DIR='../BCC.v1-bccd.coco/train/_annotations.coco.json'\n",
"OUTPUT_TFRECORD_TRAIN='../bccd_coco_tfrecords/train'\n",
"\n",
"# Need to provide\n",
" # 1. image_dir: where images are present\n",
" # 2. object_annotations_file: where annotations are listed in json format\n",
" # 3. output_file_prefix: where to write output convered TFRecords files\n",
"python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
" --image_dir=${TRAIN_DATA_DIR} \\\n",
" --object_annotations_file=${TRAIN_ANNOTATION_FILE_DIR} \\\n",
" --output_file_prefix=$OUTPUT_TFRECORD_TRAIN \\\n",
" --num_shards=1"
]
},
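{
"cell_type": "markdown",
"metadata": {},
"source": [
"To confirm the conversion worked, a minimal sketch (not part of the original tutorial) is to parse one record from the generated shard and list its feature keys. The path below assumes the single-shard output of the command above:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"raw_ds = tf.data.TFRecordDataset('../bccd_coco_tfrecords/train-00000-of-00001.tfrecord')\n",
"for raw_record in raw_ds.take(1):\n",
"    example = tf.train.Example()\n",
"    example.ParseFromString(raw_record.numpy())\n",
"    print(sorted(example.features.feature.keys()))\n",
"```"
]
},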
{
"cell_type": "markdown",
"metadata": {
"id": "VuwZpwUoaAKU"
},
"source": [
"### CLI command to convert data(validation data)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "q8mQ8prGY8kh"
},
"outputs": [],
"source": [
"%%bash\n",
"\n",
"VALID_DATA_DIR='../BCC.v1-bccd.coco/valid'\n",
"VALID_ANNOTATION_FILE_DIR='../BCC.v1-bccd.coco/valid/_annotations.coco.json'\n",
"OUTPUT_TFRECORD_VALID='../bccd_coco_tfrecords/valid'\n",
"\n",
"python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
" --image_dir=$VALID_DATA_DIR \\\n",
" --object_annotations_file=$VALID_ANNOTATION_FILE_DIR \\\n",
" --output_file_prefix=$OUTPUT_TFRECORD_VALID \\\n",
" --num_shards=1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BYGxNNAXaCW6"
},
"source": [
"### CLI command to convert data(test data)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-K8qlfstY-Ua"
},
"outputs": [],
"source": [
"%%bash\n",
"\n",
"TEST_DATA_DIR='../BCC.v1-bccd.coco/test'\n",
"TEST_ANNOTATION_FILE_DIR='../BCC.v1-bccd.coco/test/_annotations.coco.json'\n",
"OUTPUT_TFRECORD_TEST='../bccd_coco_tfrecords/test'\n",
"\n",
"python -m official.vision.data.create_coco_tf_record --logtostderr \\\n",
" --image_dir=$TEST_DATA_DIR \\\n",
" --object_annotations_file=$TEST_ANNOTATION_FILE_DIR \\\n",
" --output_file_prefix=$OUTPUT_TFRECORD_TEST \\\n",
" --num_shards=1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cW7hQEJTaEtj"
},
"source": [
"## Configure the Retinanet Resnet FPN COCO model for custom dataset.\n",
"\n",
"Dataset used for fine tuning the checkpoint is Blood Cells Detection (BCCD)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PMGEl7iXZAAF"
},
"outputs": [],
"source": [
"train_data_input_path = '../bccd_coco_tfrecords/train-00000-of-00001.tfrecord'\n",
"valid_data_input_path = '../bccd_coco_tfrecords/valid-00000-of-00001.tfrecord'\n",
"test_data_input_path = '../bccd_coco_tfrecords/test-00000-of-00001.tfrecord'\n",
"model_dir = '../trained_model/'\n",
"export_dir ='../exported_model/'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2DJpKvdeaHF3"
},
"source": [
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
"\n",
"\n",
"Use the `retinanet_resnetfpn_coco` experiment configuration, as defined by `tfm.vision.configs.retinanet.retinanet_resnetfpn_coco`.\n",
"\n",
"The configuration defines an experiment to train a RetinanNet with Resnet-50 as backbone, FPN as decoder. Default Configuration is trained on [COCO](https://cocodataset.org/) train2017 and evaluated on [COCO](https://cocodataset.org/) val2017.\n",
"\n",
"There are also other alternative experiments available such as\n",
"`retinanet_resnetfpn_coco`, `retinanet_spinenet_coco`, `fasterrcnn_resnetfpn_coco` and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function.\n",
"\n",
"We are going to fine tune the Resnet-50 backbone checkpoint which is already present in the default configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ie1ObPH9ZBpa"
},
"outputs": [],
"source": [
"exp_config = exp_factory.get_exp_config('retinanet_resnetfpn_coco')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LFhjFkw-alba"
},
"source": [
"### Adjust the model and dataset configurations so that it works with custom dataset(in this case `BCCD`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ej7j6dvIZDQA"
},
"outputs": [],
"source": [
"batch_size = 8\n",
"num_classes = 3\n",
"\n",
"HEIGHT, WIDTH = 256, 256\n",
"IMG_SIZE = [HEIGHT, WIDTH, 3]\n",
"\n",
"# Backbone config.\n",
"exp_config.task.freeze_backbone = False\n",
"exp_config.task.annotation_file = ''\n",
"\n",
"# Model config.\n",
"exp_config.task.model.input_size = IMG_SIZE\n",
"exp_config.task.model.num_classes = num_classes + 1\n",
"exp_config.task.model.detection_generator.tflite_post_processing.max_classes_per_detection = exp_config.task.model.num_classes\n",
"\n",
"# Training data config.\n",
"exp_config.task.train_data.input_path = train_data_input_path\n",
"exp_config.task.train_data.dtype = 'float32'\n",
"exp_config.task.train_data.global_batch_size = batch_size\n",
"exp_config.task.train_data.parser.aug_scale_max = 1.0\n",
"exp_config.task.train_data.parser.aug_scale_min = 1.0\n",
"\n",
"# Validation data config.\n",
"exp_config.task.validation_data.input_path = valid_data_input_path\n",
"exp_config.task.validation_data.dtype = 'float32'\n",
"exp_config.task.validation_data.global_batch_size = batch_size"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ROVc1rayaqI1"
},
"source": [
"### Adjust the trainer configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BZsCVBafZFIE"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'GPU'\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'TPU'\n",
"else:\n",
" print('Running on CPU is slow, so only train for a few steps.')\n",
" device = 'CPU'\n",
"\n",
"\n",
"train_steps = 1000\n",
"exp_config.trainer.steps_per_loop = 100 # steps_per_loop = num_of_training_examples // train_batch_size\n",
"\n",
"exp_config.trainer.summary_interval = 100\n",
"exp_config.trainer.checkpoint_interval = 100\n",
"exp_config.trainer.validation_interval = 100\n",
"exp_config.trainer.validation_steps = 100 # validation_steps = num_of_validation_examples // eval_batch_size\n",
"exp_config.trainer.train_steps = train_steps\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100\n",
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XS6cJfs2atgI"
},
"source": [
"### Print the modified configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IvfJlqI7ZIcD"
},
"outputs": [],
"source": [
"pp.pprint(exp_config.as_dict())\n",
"display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6o5mbpRBawbs"
},
"source": [
"### Set up the distribution strategy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2NvY8QHOZKGr"
},
"outputs": [],
"source": [
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" tf.tpu.experimental.initialize_tpu_system()\n",
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"else:\n",
" print('Warning: this will be really slow.')\n",
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
"\n",
"print('Done')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4wPtJgoOa33v"
},
"source": [
"## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
"\n",
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ns9LAsiXZLuX"
},
"outputs": [],
"source": [
"with distribution_strategy.scope():\n",
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vTKbQxDkbArE"
},
"source": [
"## Visualize a batch of the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3RIlbhj0ZNvt"
},
"outputs": [],
"source": [
"for images, labels in task.build_inputs(exp_config.task.train_data).take(1):\n",
" print()\n",
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
" print(f'labels.keys: {labels.keys()}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m-QW7DoKbD8z"
},
"source": [
"### Create category index dictionary to map the labels to coressponding label names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MN0sSthbZR-s"
},
"outputs": [],
"source": [
"category_index={\n",
" 1: {\n",
" 'id': 1,\n",
" 'name': 'Platelets'\n",
" },\n",
" 2: {\n",
" 'id': 2,\n",
" 'name': 'RBC'\n",
" },\n",
" 3: {\n",
" 'id': 3,\n",
" 'name': 'WBC'\n",
" }\n",
"}\n",
"tf_ex_decoder = TfExampleDecoder()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AcbmD1pRbGcS"
},
"source": [
"### Helper function for visualizing the results from TFRecords.\n",
"Use `visualize_boxes_and_labels_on_image_array` from `visualization_utils` to draw boudning boxes on the image."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wWBeomMMZThI"
},
"outputs": [],
"source": [
"def show_batch(raw_records, num_of_examples):\n",
" plt.figure(figsize=(20, 20))\n",
" use_normalized_coordinates=True\n",
" min_score_thresh = 0.30\n",
" for i, serialized_example in enumerate(raw_records):\n",
" plt.subplot(1, 3, i + 1)\n",
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
" image = decoded_tensors['image'].numpy().astype('uint8')\n",
" scores = np.ones(shape=(len(decoded_tensors['groundtruth_boxes'])))\n",
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
" image,\n",
" decoded_tensors['groundtruth_boxes'].numpy(),\n",
" decoded_tensors['groundtruth_classes'].numpy().astype('int'),\n",
" scores,\n",
" category_index=category_index,\n",
" use_normalized_coordinates=use_normalized_coordinates,\n",
" max_boxes_to_draw=200,\n",
" min_score_thresh=min_score_thresh,\n",
" agnostic_mode=False,\n",
" instance_masks=None,\n",
" line_thickness=4)\n",
"\n",
" plt.imshow(image)\n",
" plt.axis('off')\n",
" plt.title(f'Image-{i+1}')\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R3EgriELbJly"
},
"source": [
"### Visualization of train data\n",
"\n",
"The bounding box detection has two components\n",
" 1. Class label of the object detected (e.g.RBC)\n",
" 2. Percentage of match between predicted and ground truth bounding boxes.\n",
"\n",
"**Note**: The reason of everything is 100% is because we are visualising the groundtruth."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hdrsciGIZVNO"
},
"outputs": [],
"source": [
"buffer_size = 20\n",
"num_of_examples = 3\n",
"\n",
"raw_records = tf.data.TFRecordDataset(\n",
" exp_config.task.train_data.input_path).shuffle(\n",
" buffer_size=buffer_size).take(num_of_examples)\n",
"show_batch(raw_records, num_of_examples)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IrWkJPyEbMKg"
},
"source": [
"## Train and evaluate.\n",
"\n",
"We follow the COCO challenge tradition to evaluate the accuracy of object detection based on mAP(mean Average Precision). Please check [here](https://cocodataset.org/#detection-eval) for detail explanation of how evaluation metrics for detection task is done.\n",
"\n",
"**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted bounding box and ground truth bounding box."
]
},
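{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a worked example (a minimal sketch, not part of the original tutorial), the IoU of two boxes given as `[ymin, xmin, ymax, xmax]` can be computed directly:\n",
"\n",
"```python\n",
"def box_iou(box_a, box_b):\n",
"  # Boxes are [ymin, xmin, ymax, xmax].\n",
"  inter_h = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))\n",
"  inter_w = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))\n",
"  intersection = inter_h * inter_w\n",
"  area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])\n",
"  area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])\n",
"  return intersection / (area_a + area_b - intersection)\n",
"\n",
"print(box_iou([0., 0., 2., 2.], [1., 1., 3., 3.]))  # 1 / 7, i.e. ~0.143\n",
"```"
]
},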
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SCjHHXvfZXX1"
},
"outputs": [],
"source": [
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
" distribution_strategy=distribution_strategy,\n",
" task=task,\n",
" mode='train_and_eval',\n",
" params=exp_config,\n",
" model_dir=model_dir,\n",
" run_post_eval=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2Gd6uHLjbPKW"
},
"source": [
"## Load logs in tensorboard."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Q6iDRUVqZY86"
},
"outputs": [],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir '../trained_model/'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AoL2MIJobReU"
},
"source": [
"## Saving and exporting the trained model.\n",
"\n",
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CmOBYXdXZah4"
},
"outputs": [],
"source": [
"export_saved_model_lib.export_inference_graph(\n",
" input_type='image_tensor',\n",
" batch_size=1,\n",
" input_image_size=[HEIGHT, WIDTH],\n",
" params=exp_config,\n",
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
" export_dir=export_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_JhXopm8bU1g"
},
"source": [
"## Inference from trained model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EbD4j1uCZcIV"
},
"outputs": [],
"source": [
"def load_image_into_numpy_array(path):\n",
" \"\"\"Load an image from file into a numpy array.\n",
"\n",
" Puts image into numpy array to feed into tensorflow graph.\n",
" Note that by convention we put it into a numpy array with shape\n",
" (height, width, channels), where channels=3 for RGB.\n",
"\n",
" Args:\n",
" path: the file path to the image\n",
"\n",
" Returns:\n",
" uint8 numpy array with shape (img_height, img_width, 3)\n",
" \"\"\"\n",
" image = None\n",
" if(path.startswith('http')):\n",
" response = urlopen(path)\n",
" image_data = response.read()\n",
" image_data = BytesIO(image_data)\n",
" image = Image.open(image_data)\n",
" else:\n",
" image_data = tf.io.gfile.GFile(path, 'rb').read()\n",
" image = Image.open(BytesIO(image_data))\n",
"\n",
" (im_width, im_height) = image.size\n",
" return np.array(image.getdata()).reshape(\n",
" (1, im_height, im_width, 3)).astype(np.uint8)\n",
"\n",
"\n",
"\n",
"def build_inputs_for_object_detection(image, input_image_size):\n",
" \"\"\"Builds Object Detection model inputs for serving.\"\"\"\n",
" image, _ = resize_and_crop_image(\n",
" image,\n",
" input_image_size,\n",
" padded_size=input_image_size,\n",
" aug_scale_min=1.0,\n",
" aug_scale_max=1.0)\n",
" return image"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o8bguhK_batq"
},
"source": [
"### Visualize test data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sOsDhYmyZd_m"
},
"outputs": [],
"source": [
"num_of_examples = 3\n",
"\n",
"test_ds = tf.data.TFRecordDataset(\n",
" '../bccd_coco_tfrecords/test-00000-of-00001.tfrecord').take(\n",
" num_of_examples)\n",
"show_batch(test_ds, num_of_examples)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kcYnb1Zfbba9"
},
"source": [
"### Importing SavedModel."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nQ6waz9rZfhy"
},
"outputs": [],
"source": [
"imported = tf.saved_model.load(export_dir)\n",
"model_fn = imported.signatures['serving_default']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CtB4gfZ3bfiC"
},
"source": [
"### Visualize predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UTSfNZ6yZhEV"
},
"outputs": [],
"source": [
"input_image_size = (HEIGHT, WIDTH)\n",
"plt.figure(figsize=(20, 20))\n",
"min_score_thresh = 0.30 # Change minimum score for threshold to see all bounding boxes confidences.\n",
"\n",
"for i, serialized_example in enumerate(test_ds):\n",
" plt.subplot(1, 3, i+1)\n",
" decoded_tensors = tf_ex_decoder.decode(serialized_example)\n",
" image = build_inputs_for_object_detection(decoded_tensors['image'], input_image_size)\n",
" image = tf.expand_dims(image, axis=0)\n",
" image = tf.cast(image, dtype = tf.uint8)\n",
" image_np = image[0].numpy()\n",
" result = model_fn(image)\n",
" visualization_utils.visualize_boxes_and_labels_on_image_array(\n",
" image_np,\n",
" result['detection_boxes'][0].numpy(),\n",
" result['detection_classes'][0].numpy().astype(int),\n",
" result['detection_scores'][0].numpy(),\n",
" category_index=category_index,\n",
" use_normalized_coordinates=False,\n",
" max_boxes_to_draw=200,\n",
" min_score_thresh=min_score_thresh,\n",
" agnostic_mode=False,\n",
" instance_masks=None,\n",
" line_thickness=4)\n",
" plt.imshow(image_np)\n",
" plt.axis('off')\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"colab": {
"name": "object_detection.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "uY4QMaQw9Yvi"
},
"source": [
"##### Copyright 2022 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "NM0OBLSN9heW"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sg-GchQwFr_r"
},
"source": [
"# Semantic Segmentation with Model Garden\n",
"\n",
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/tfmodels/vision/semantic_segmentation\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/docs/vision/semantic_segmentation.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c6J4IoNfN9jp"
},
"source": [
"This tutorial trains a [DeepLabV3](https://arxiv.org/pdf/1706.05587.pdf) with [Mobilenet V2](https://arxiv.org/abs/1801.04381) as backbone model from the [TensorFlow Model Garden](https://pypi.org/project/tf-models-official/) package (tensorflow-models).\n",
"\n",
"\n",
"[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
"\n",
"**Dataset**: [Oxford-IIIT Pets](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet)\n",
"\n",
"* The Oxford-IIIT pet dataset is a 37 category pet image dataset with roughly 200 images for each class. The images have large variations in scale, pose and lighting. All images have an associated ground truth annotation of breed.\n",
"\n",
"\n",
"**This tutorial demonstrates how to:**\n",
"\n",
"1. Use models from the TensorFlow Models package.\n",
"2. Train/Fine-tune a pre-built DeepLabV3 with mobilenet as backbone for Semantic Segmentation.\n",
"3. Export the trained/tuned DeepLabV3 model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AlxYhP0XFnDn"
},
"source": [
"## Install necessary dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pXWAySwgaWpN"
},
"outputs": [],
"source": [
"!pip install -U -q \"tensorflow\u003e=2.9.2\" \"tf-models-official\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uExUsXlgaPD6"
},
"source": [
"## Import required libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mOmKZ3Vky5t9"
},
"outputs": [],
"source": [
"import os\n",
"import pprint\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from IPython import display"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nF8IHrXua_0b"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_datasets as tfds\n",
"\n",
"\n",
"import orbit\n",
"import tensorflow_models as tfm\n",
"from official.vision.data import tfrecord_lib\n",
"from official.vision.serving import export_saved_model_lib\n",
"from official.vision.utils.object_detection import visualization_utils\n",
"\n",
"pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
"print(tf.__version__) # Check the version of tensorflow used\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gMs4l2dpaTd3"
},
"source": [
"## Custom dataset preparation for semantic segmentation\n",
"Models in Official repository (of model-garden) require models in a TFRecords dataformat.\n",
"\n",
"Please check [this resource](https://www.tensorflow.org/tutorials/load_data/tfrecord) to learn more about TFRecords data format.\n",
"\n",
"[Oxford_IIIT_pet:3](https://www.tensorflow.org/datasets/catalog/oxford_iiit_pet) dataset is taken from [Tensorflow Datasets](https://www.tensorflow.org/datasets/catalog/overview)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JpWK1Z-N3fHh"
},
"outputs": [],
"source": [
"(train_ds, val_ds, test_ds), info = tfds.load(\n",
" 'oxford_iiit_pet:3.*.*',\n",
" split=['train+test[:50%]', 'test[50%:80%]', 'test[80%:100%]'],\n",
" with_info=True)\n",
"info"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sq6s11E1bMJB"
},
"source": [
"### Helper function to encode dataset as tfrecords"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NlEf_C-DjDHG"
},
"outputs": [],
"source": [
"def process_record(record):\n",
" keys_to_features = {\n",
" 'image/encoded': tfrecord_lib.convert_to_feature(\n",
" tf.io.encode_jpeg(record['image']).numpy()),\n",
" 'image/height': tfrecord_lib.convert_to_feature(record['image'].shape[0]),\n",
" 'image/width': tfrecord_lib.convert_to_feature(record['image'].shape[1]),\n",
" 'image/segmentation/class/encoded':tfrecord_lib.convert_to_feature(\n",
" tf.io.encode_png(record['segmentation_mask'] - 1).numpy())\n",
" }\n",
" example = tf.train.Example(\n",
" features=tf.train.Features(feature=keys_to_features))\n",
" return example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FoapGlIebP9r"
},
"source": [
"### Write TFRecords to a folder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dDbMn5q551LQ"
},
"outputs": [],
"source": [
"output_dir = './oxford_iiit_pet_tfrecords/'\n",
"LOG_EVERY = 100\n",
"if not os.path.exists(output_dir):\n",
" os.mkdir(output_dir)\n",
"\n",
"def write_tfrecords(dataset, output_path, num_shards=1):\n",
" writers = [\n",
" tf.io.TFRecordWriter(\n",
" output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
" for i in range(num_shards)\n",
" ]\n",
" for idx, record in enumerate(dataset):\n",
" if idx % LOG_EVERY == 0:\n",
" print('On image %d', idx)\n",
" tf_example = process_record(record)\n",
" writers[idx % num_shards].write(tf_example.SerializeToString())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QHDD-D7rbZj7"
},
"source": [
"### Write training data as TFRecords"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qxJnVUfT0qBJ"
},
"outputs": [],
"source": [
"output_train_tfrecs = output_dir + 'train'\n",
"write_tfrecords(train_ds, output_train_tfrecs, num_shards=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ap55RwVFbhtu"
},
"source": [
"### Write validation data as TFRecords"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fgq-VxF79ucR"
},
"outputs": [],
"source": [
"output_val_tfrecs = output_dir + 'val'\n",
"write_tfrecords(val_ds, output_val_tfrecs, num_shards=5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0AZoIEzRbxZu"
},
"source": [
"### Write test data as TFRecords"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QmwFmbP69t0U"
},
"outputs": [],
"source": [
"output_test_tfrecs = output_dir + 'test'\n",
"write_tfrecords(test_ds, output_test_tfrecs, num_shards=5)"
]
},
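{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional check (a minimal sketch, not part of the original tutorial), one written record can be parsed back with the same feature keys used in `process_record`, confirming that the image and mask round-trip correctly:\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"features = {\n",
"    'image/encoded': tf.io.FixedLenFeature([], tf.string),\n",
"    'image/height': tf.io.FixedLenFeature([], tf.int64),\n",
"    'image/width': tf.io.FixedLenFeature([], tf.int64),\n",
"    'image/segmentation/class/encoded': tf.io.FixedLenFeature([], tf.string),\n",
"}\n",
"\n",
"raw_ds = tf.data.TFRecordDataset(tf.io.gfile.glob('./oxford_iiit_pet_tfrecords/train*'))\n",
"for raw in raw_ds.take(1):\n",
"    parsed = tf.io.parse_single_example(raw, features)\n",
"    image = tf.io.decode_jpeg(parsed['image/encoded'])\n",
"    mask = tf.io.decode_png(parsed['image/segmentation/class/encoded'])\n",
"    print('image:', image.shape, 'mask:', mask.shape)\n",
"```"
]
},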
{
"cell_type": "markdown",
"metadata": {
"id": "uEFzV-6ZfBZW"
},
"source": [
"## Configure the DeepLabV3 Mobilenet model for custom dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_LPEIvLsqSaG"
},
"outputs": [],
"source": [
"train_data_tfrecords = './oxford_iiit_pet_tfrecords/train*'\n",
"val_data_tfrecords = './oxford_iiit_pet_tfrecords/val*'\n",
"test_data_tfrecords = './oxford_iiit_pet_tfrecords/test*'\n",
"trained_model = './trained_model/'\n",
"export_dir = './exported_model/'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1ZlSiSRyb1Q6"
},
"source": [
"In Model Garden, the collections of parameters that define a model are called *configs*. Model Garden can create a config based on a known set of parameters via a [factory](https://en.wikipedia.org/wiki/Factory_method_pattern).\n",
"\n",
"\n",
"Use the `mnv2_deeplabv3_pascal` experiment configuration, as defined by `tfm.vision.configs.semantic_segmentation.mnv2_deeplabv3_pascal`.\n",
"\n",
"Please find all the registered experiements [here](https://www.tensorflow.org/api_docs/python/tfm/core/exp_factory/get_exp_config)\n",
"\n",
"The configuration defines an experiment to train a [DeepLabV3](https://arxiv.org/pdf/1706.05587.pdf) model with MobilenetV2 as backbone and [ASPP](https://arxiv.org/pdf/1606.00915v2.pdf) as decoder.\n",
"\n",
"There are also other alternative experiments available such as\n",
"\n",
"* `seg_deeplabv3_pascal`\n",
"* `seg_deeplabv3plus_pascal`\n",
"* `seg_resnetfpn_pascal`\n",
"* `mnv2_deeplabv3plus_cityscapes`\n",
"\n",
"and more. One can switch to them by changing the experiment name argument to the `get_exp_config` function.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bj5UZ6BkfJCX"
},
"outputs": [],
"source": [
"exp_config = tfm.core.exp_factory.get_exp_config('mnv2_deeplabv3_pascal')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B8jyG-jGIdFs"
},
"outputs": [],
"source": [
"model_ckpt_path = './model_ckpt/'\n",
"if not os.path.exists(model_ckpt_path):\n",
" os.mkdir(model_ckpt_path)\n",
"\n",
"!gsutil cp gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63.data-00000-of-00001 './model_ckpt/'\n",
"!gsutil cp gs://tf_model_garden/cloud/vision-2.0/deeplab/deeplabv3_mobilenetv2_coco/best_ckpt-63.index './model_ckpt/'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QBYvVFZXhSGQ"
},
"source": [
"### Adjust the model and dataset configurations so that it works with custom dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o_Z_vWW9-5Sy"
},
"outputs": [],
"source": [
"num_classes = 3\n",
"WIDTH, HEIGHT = 128, 128\n",
"input_size = [HEIGHT, WIDTH, 3]\n",
"BATCH_SIZE = 16\n",
"\n",
"# Backbone Config\n",
"exp_config.task.init_checkpoint = model_ckpt_path + 'best_ckpt-63'\n",
"exp_config.task.freeze_backbone = True\n",
"\n",
"# Model Config\n",
"exp_config.task.model.num_classes = num_classes\n",
"exp_config.task.model.input_size = input_size\n",
"\n",
"# Training Data Config\n",
"exp_config.task.train_data.aug_scale_min = 1.0\n",
"exp_config.task.train_data.aug_scale_max = 1.0\n",
"exp_config.task.train_data.input_path = train_data_tfrecords\n",
"exp_config.task.train_data.global_batch_size = BATCH_SIZE\n",
"exp_config.task.train_data.dtype = 'float32'\n",
"exp_config.task.train_data.output_size = [HEIGHT, WIDTH]\n",
"exp_config.task.train_data.preserve_aspect_ratio = False\n",
"exp_config.task.train_data.seed = 21 # Reproducable Training Data\n",
"\n",
"# Validation Data Config\n",
"exp_config.task.validation_data.input_path = val_data_tfrecords\n",
"exp_config.task.validation_data.global_batch_size = BATCH_SIZE\n",
"exp_config.task.validation_data.dtype = 'float32'\n",
"exp_config.task.validation_data.output_size = [HEIGHT, WIDTH]\n",
"exp_config.task.validation_data.preserve_aspect_ratio = False\n",
"exp_config.task.validation_data.groundtruth_padded_size = [HEIGHT, WIDTH]\n",
"exp_config.task.validation_data.seed = 21 # Reproducable Validation Data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0HDg5eKniMGJ"
},
"source": [
"### Adjust the trainer configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WASJZ3gUH8ni"
},
"outputs": [],
"source": [
"logical_device_names = [logical_device.name\n",
" for logical_device in tf.config.list_logical_devices()]\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'GPU'\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" print('This may be broken in Colab.')\n",
" device = 'TPU'\n",
"else:\n",
" print('Running on CPU is slow, so only train for a few steps.')\n",
" device = 'CPU'\n",
"\n",
"\n",
"train_steps = 2000\n",
"exp_config.trainer.steps_per_loop = int(train_ds.__len__().numpy() // BATCH_SIZE)\n",
"\n",
"exp_config.trainer.summary_interval = exp_config.trainer.steps_per_loop # steps_per_loop = num_of_validation_examples // eval_batch_size\n",
"exp_config.trainer.checkpoint_interval = exp_config.trainer.steps_per_loop\n",
"exp_config.trainer.validation_interval = exp_config.trainer.steps_per_loop\n",
"exp_config.trainer.validation_steps = int(train_ds.__len__().numpy() // BATCH_SIZE) # validation_steps = num_of_validation_examples // eval_batch_size\n",
"exp_config.trainer.train_steps = train_steps\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = exp_config.trainer.steps_per_loop\n",
"exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps\n",
"exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1\n",
"exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R66w5MwkiO8Z"
},
"source": [
"### Print the modified configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ckpjzrqfhoSn"
},
"outputs": [],
"source": [
"pp.pprint(exp_config.as_dict())\n",
"display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FYwzdGKAiSOV"
},
"source": [
"### Set up the distribution strategy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iwiOuYRRqdBi"
},
"outputs": [],
"source": [
"# Setting up the Strategy\n",
"if exp_config.runtime.mixed_precision_dtype == tf.float16:\n",
" tf.keras.mixed_precision.set_global_policy('mixed_float16')\n",
"\n",
"if 'GPU' in ''.join(logical_device_names):\n",
" distribution_strategy = tf.distribute.MirroredStrategy()\n",
"elif 'TPU' in ''.join(logical_device_names):\n",
" tf.tpu.experimental.initialize_tpu_system()\n",
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='/device:TPU_SYSTEM:0')\n",
" distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"else:\n",
" print('Warning: this will be really slow.')\n",
" distribution_strategy = tf.distribute.OneDeviceStrategy(logical_device_names[0])\n",
"\n",
"print(\"Done\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZLtk1GIIiVR2"
},
"source": [
"## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
"\n",
"The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ASTB5D2UISSr"
},
"outputs": [],
"source": [
"model_dir = './trained_model/'\n",
"\n",
"with distribution_strategy.scope():\n",
" task = tfm.core.task_factory.get_task(exp_config.task, logging_dir=model_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YIQ26TW-ihzA"
},
"source": [
"## Visualize a batch of the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "412WyIUAIdCr"
},
"outputs": [],
"source": [
"for images, masks in task.build_inputs(exp_config.task.train_data).take(1):\n",
" print()\n",
" print(f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')\n",
" print(f'masks.shape: {str(masks[\"masks\"].shape):16} images.dtype: {masks[\"masks\"].dtype!r}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3GgluDVJixMd"
},
"source": [
"### Helper function for visualizing the results from TFRecords"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1kueMMfERvLx"
},
"outputs": [],
"source": [
"def display(display_list):\n",
" plt.figure(figsize=(15, 15))\n",
"\n",
" title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
"\n",
" for i in range(len(display_list)):\n",
" plt.subplot(1, len(display_list), i+1)\n",
" plt.title(title[i])\n",
" plt.imshow(tf.keras.utils.array_to_img(display_list[i]))\n",
"\n",
"\n",
" plt.axis('off')\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZCtt09G7i3dq"
},
"source": [
"### Visualization of training data\n",
"\n",
"Image Title represents what is depicted from the image.\n",
"\n",
"Same helper function can be used while visualizing predicted mask"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YwUPf9V2B6SR"
},
"outputs": [],
"source": [
"num_examples = 3\n",
"\n",
"for images, masks in task.build_inputs(exp_config.task.train_data).take(num_examples):\n",
" display([images[0], masks['masks'][0]])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MeJ5w8KfjMmP"
},
"source": [
"## Train and evaluate\n",
"**IoU**: is defined as the area of the intersection divided by the area of the union of a predicted mask and ground truth mask."
]
},
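{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a worked example (a minimal sketch, not part of the original tutorial), the IoU between a predicted mask and a ground truth mask for a single class is the number of pixels in the intersection of the two binary masks divided by the number of pixels in their union:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def mask_iou(pred_mask, true_mask, class_id):\n",
"  pred = (pred_mask == class_id)\n",
"  true = (true_mask == class_id)\n",
"  intersection = np.logical_and(pred, true).sum()\n",
"  union = np.logical_or(pred, true).sum()\n",
"  return intersection / union if union else 0.0\n",
"\n",
"pred = np.array([[0, 0], [1, 1]])\n",
"true = np.array([[0, 1], [1, 1]])\n",
"print(mask_iou(pred, true, class_id=1))  # 2 shared pixels / 3 in the union -> ~0.667\n",
"```"
]
},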
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ru3aHTCySHoH"
},
"outputs": [],
"source": [
"model, eval_logs = tfm.core.train_lib.run_experiment(\n",
" distribution_strategy=distribution_strategy,\n",
" task=task,\n",
" mode='train_and_eval',\n",
" params=exp_config,\n",
" model_dir=model_dir,\n",
" run_post_eval=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vt3WmtxhjfGe"
},
"source": [
"## Load logs in tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A9rct_7BoJFb"
},
"outputs": [],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir './trained_model'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "v6XaGoUuji7P"
},
"source": [
"## Saving and exporting the trained model\n",
"\n",
"The `keras.Model` object returned by `train_lib.run_experiment` expects the data to be normalized by the dataset loader using the same mean and variance statiscics in `preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)`. This export function handles those details, so you can pass `tf.uint8` images and get the correct results."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GVsnyqzdnxHd"
},
"outputs": [],
"source": [
"export_saved_model_lib.export_inference_graph(\n",
" input_type='image_tensor',\n",
" batch_size=1,\n",
" input_image_size=[HEIGHT, WIDTH],\n",
" params=exp_config,\n",
" checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
" export_dir=export_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nM1S-tjIjvAr"
},
"source": [
"## Importing SavedModel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Nxi9pEwluUcT"
},
"outputs": [],
"source": [
"imported = tf.saved_model.load(export_dir)\n",
"model_fn = imported.signatures['serving_default']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LbBfl6AUj_My"
},
"source": [
"## Visualize predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qifGt_ohpFhn"
},
"outputs": [],
"source": [
"def create_mask(pred_mask):\n",
" pred_mask = tf.math.argmax(pred_mask, axis=-1)\n",
" pred_mask = pred_mask[..., tf.newaxis]\n",
" return pred_mask[0]\n",
"\n",
"\n",
"for record in test_ds.take(15):\n",
" image = tf.image.resize(record['image'], size=[HEIGHT, WIDTH])\n",
" image = tf.cast(image, dtype=tf.uint8)\n",
" mask = tf.image.resize(record['segmentation_mask'], size=[HEIGHT, WIDTH])\n",
" predicted_mask = model_fn(tf.expand_dims(record['image'], axis=0))\n",
" display([image, mask, create_mask(predicted_mask['logits'])])"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "semantic_segmentation.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
2024-03-13 04:28:11.344384: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-13 04:28:14.571490: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 31985 MB memory: -> device: 0, name: Z100L, pci bus id: 0000:a6:00.0
2024-03-13 04:28:14.608583: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31985 MB memory: -> device: 1, name: Z100L, pci bus id: 0000:a9:00.0
2024-03-13 04:28:14.645036: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 31985 MB memory: -> device: 2, name: Z100L, pci bus id: 0000:ac:00.0
2024-03-13 04:28:14.684236: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1639] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 31985 MB memory: -> device: 3, name: Z100L, pci bus id: 0000:af:00.0
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
I0313 04:28:14.736542 139980117518144 mirrored_strategy.py:419] Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2024-03-13 04:28:14.765982: I tensorflow/core/common_runtime/gpu_fusion_pass.cc:508] ROCm Fusion is enabled.
I0313 04:28:16.548994 139980117518144 resnet_ctl_imagenet_main.py:136] Training 1 epochs, each epoch has 10009 steps, total steps: 10009; Eval 391 steps
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.679222 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.687673 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.748045 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.755074 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.811109 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.819525 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.884291 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.891215 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.934982 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:16.942033 139980117518144 cross_device_ops.py:617] Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
I0313 04:28:20.093187 139980117518144 imagenet_preprocessing.py:321] Sharding the dataset: input_pipeline_id=0 num_input_pipelines=1
I0313 04:28:20.872307 139980117518144 imagenet_preprocessing.py:321] Sharding the dataset: input_pipeline_id=0 num_input_pipelines=1
I0313 04:28:21.196566 139980117518144 controller.py:449] restoring or initializing model...
I0313 04:28:21.212172 139980117518144 controller.py:267] train | step: 0 | training until step 10009...
INFO:tensorflow:Collective all_reduce tensors: 161 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
I0313 04:28:29.238857 139980117518144 cross_device_ops.py:1152] Collective all_reduce tensors: 161 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 161 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
I0313 04:28:42.643724 139980117518144 cross_device_ops.py:1152] Collective all_reduce tensors: 161 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
I0313 05:01:59.824195 139980117518144 keras_utils.py:145] TimeHistory: 2018.54 seconds, 634.69 examples/second between steps 0 and 10009
I0313 05:01:59.830389 139980117518144 controller.py:520] train | step: 10009 | steps/sec: 5.0 | output: {'train_accuracy': 0.028023997, 'train_loss': 2.7792585}
I0313 05:01:59.854098 139980117518144 controller.py:565] Sync on async checkpoint saving.
I0313 05:01:59.855006 139980117518144 controller.py:310] eval | step: 10009 | running 391 steps of evaluation...
I0313 05:02:37.262611 139980117518144 controller.py:331] eval | step: 10009 | steps/sec: 10.5 | eval time: 37.4 sec | output:
{'steps_per_second': 10.452739501846827,
'test_accuracy': 0.09744,
'test_loss': 1.2336843}
I0313 05:02:37.272547 139980117518144 controller.py:565] Sync on async checkpoint saving.
I0313 05:02:37.292726 139980117518144 resnet_ctl_imagenet_main.py:189] Run stats:
{'eval_loss': 1.2336843, 'eval_acc': 0.09744, 'train_loss': 2.7792585, 'train_acc': 0.028023997, 'step_timestamp_log': ['BatchTimestamp<batch_index: 0, timestamp: 1710304101.2805958>', 'BatchTimestamp<batch_index: 10009, timestamp: 1710306119.823979>'], 'train_finish_time': 1710306157.2727807, 'avg_exp_per_second': 634.6907446249477}
restoring or initializing model...
train | step: 0 | training until step 10009...
train | step: 10009 | steps/sec: 5.0 | output: {'train_accuracy': 0.028023997, 'train_loss': 2.7792585}
eval | step: 10009 | running 391 steps of evaluation...
eval | step: 10009 | steps/sec: 10.5 | eval time: 37.4 sec | output:
{'steps_per_second': 10.452739501846827,
'test_accuracy': 0.09744,
'test_loss': 1.2336843}
# Officially Supported TensorFlow 2.1+ Models on Cloud TPU
## Natural Language Processing
* [bert](nlp/bert): A powerful pre-trained language representation model:
BERT, which stands for Bidirectional Encoder Representations from
Transformers.
[BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step-by-step instructions on Cloud TPU training. See [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for the MNLI fine-tuning task.
* [transformer](nlp/transformer): A transformer model to translate the WMT
English to German dataset.
See [Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step-by-step instructions on Cloud TPU training.
## Computer Vision
* [efficientnet](vision/image_classification): A family of convolutional
neural networks that scale by balancing network depth, width, and
resolution and can be used to classify ImageNet's dataset of 1000 classes.
See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
* [mnist](vision/image_classification): A basic model to classify digits
from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
* [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
* [resnet](vision/image_classification): A deep residual network that can
be used to classify ImageNet's dataset of 1000 classes.
See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
* [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
* [shapemask](vision/detection): An object detection and instance segmentation model using shape priors. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/ZbXgVoc6Rf6mBRlPj0JpLA).
## Recommendation
* [dlrm](recommendation/ranking): [Deep Learning Recommendation Model for
Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
* [dcn v2](recommendation/ranking): [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535).
* [ncf](recommendation): Neural Collaborative Filtering. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/0k3gKjZlR1ewkVTRyLB6IQ).
<div align="center">
<img src="https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png">
</div>
# TensorFlow Official Models
The TensorFlow official models are a collection of models
that use TensorFlow’s high-level APIs.
They are intended to be well-maintained, tested, and kept up to date
with the latest TensorFlow API.
They should also be reasonably optimized for fast performance while still
being easy to read.
These models are used as end-to-end tests, ensuring that the models run
with the same or improved speed and performance with each new TensorFlow build.
The API documentation of the latest stable release is published to
[tensorflow.org](https://www.tensorflow.org/api_docs/python/tfm).
## More models to come!
The team is actively developing new models.
In the near future, we will add:
* State-of-the-art language understanding models.
* State-of-the-art image classification models.
* State-of-the-art object detection and instance segmentation models.
* State-of-the-art video classification models.
## Table of Contents
- [Models and Implementations](#models-and-implementations)
* [Computer Vision](#computer-vision)
+ [Image Classification](#image-classification)
+ [Object Detection and Segmentation](#object-detection-and-segmentation)
+ [Video Classification](#video-classification)
* [Natural Language Processing](#natural-language-processing)
* [Recommendation](#recommendation)
- [How to get started with the official models](#how-to-get-started-with-the-official-models)
- [Contributions](#contributions)
## Models and Implementations
### [Computer Vision](vision/README.md)
#### Image Classification
| Model | Reference (Paper) |
|-------|-------------------|
| [ResNet](vision/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
| [ResNet-RS](vision/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
| [EfficientNet](vision/MODEL_GARDEN.md) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
| [Vision Transformer](vision/MODEL_GARDEN.md) | [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) |
#### Object Detection and Segmentation
| Model | Reference (Paper) |
|-------|-------------------|
| [RetinaNet](vision/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
| [Mask R-CNN](vision/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [YOLO](projects/yolo/README.md) | [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696) |
| [SpineNet](vision/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
| [Cascade RCNN-RS and RetinaNet-RS](vision/MODEL_GARDEN.md) | [Simple Training Strategies and Model Scaling for Object Detection](https://arxiv.org/abs/2107.00057)|
#### Video Classification
| Model | Reference (Paper) |
|-------|-------------------|
| [Mobile Video Networks (MoViNets)](projects/movinet) | [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) |
### [Natural Language Processing](nlp/README.md)
#### Pre-trained Language Model
| Model | Reference (Paper) |
|-------|-------------------|
| [ALBERT](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
| [BERT](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
| [ELECTRA](nlp/tasks/electra_task.py) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |
#### Neural Machine Translation
| Model | Reference (Paper) |
|-------|-------------------|
| [Transformer](nlp/MODEL_GARDEN.md#available-model-configs) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
#### Natural Language Generation
| Model | Reference (Paper) |
|-------|-------------------|
| [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
#### Knowledge Distillation
| Model | Reference (Paper) |
|-------|-------------------|
| [MobileBERT](projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |
### Recommendation
Model | Reference (Paper)
-------------------------------- | -----------------
[DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091)
[DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535)
[NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031)
## How to get started with the official models
* The official models in the master branch are developed against the
[master branch of TensorFlow 2](https://github.com/tensorflow/tensorflow/tree/master).
When you clone the repository or install the `pip` binary from the master branch of
the official models, the master branch of TensorFlow is downloaded as a
dependency. This is equivalent to the following:
```shell
pip3 install tf-models-nightly
pip3 install tensorflow-text-nightly # when the model uses `nlp` packages
```
* For stable versions targeting a specific release, the TensorFlow-models
repository version numbers match the target TensorFlow release. For
example, [TensorFlow-models v2.8.x](https://github.com/tensorflow/models/releases/tag/v2.8.0)
is compatible with [TensorFlow v2.8.x](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0).
This is equivalent to the following:
```shell
pip3 install tf-models-official==2.8.0
pip3 install tensorflow-text==2.8.0 # when the model uses `nlp` packages
```
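As a quick, illustrative sanity check (not part of the original instructions; it assumes Python 3.8+ for `importlib.metadata`), you can confirm that the installed releases line up:

```python
# Minimal sketch: the tf-models-official release is expected to track the
# installed TensorFlow minor version (e.g. 2.8.x alongside TensorFlow 2.8.x).
import importlib.metadata

import tensorflow as tf

print("TensorFlow:        ", tf.__version__)
print("tf-models-official:", importlib.metadata.version("tf-models-official"))
```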
Starting from the 2.9.x release, we publish the modeling library as the
`tensorflow_models` package, and users can `import tensorflow_models` directly to
access the exported symbols. If you are
using the latest nightly version or the GitHub code directly, please follow the
docstrings on GitHub.
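As a minimal sketch of that import path (assuming `tf-models-official` 2.9+ is installed; the specific encoder built here is only an illustration):

```python
import tensorflow_models as tfm

# Exported symbols are grouped into sub-namespaces such as tfm.nlp and tfm.vision.
# Build a small BERT-style encoder from the NLP modeling library as an example.
encoder = tfm.nlp.networks.BertEncoder(vocab_size=30522, num_layers=2)
print(type(encoder))
```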
Please follow the steps below before running models in this repository.
### Requirements
* The latest TensorFlow Model Garden release and the latest TensorFlow 2
* If you are on a version of TensorFlow earlier than 2.2, please
upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
* Python 3.7+
Our integration tests run with Python 3.7. Although Python 3.6 should work, we
don't recommend earlier versions.
### Installation
Please check [here](https://github.com/tensorflow/models#Installation) for the
instructions.
Available pypi packages:
* [tf-models-official](https://pypi.org/project/tf-models-official/)
* [tf-models-nightly](https://pypi.org/project/tf-models-nightly/): nightly
release with the latest changes.
* [tf-models-no-deps](https://pypi.org/project/tf-models-no-deps/): without
`tensorflow` and `tensorflow-text` in the `install_requires` list.
## Contributions
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common benchmark class for model garden models."""
import os
import pprint
from typing import Optional
# Import libraries
from absl import logging
import gin
import tensorflow as tf
from tensorflow.python.platform import benchmark # pylint: disable=unused-import
from official.common import registry_imports # pylint: disable=unused-import
from official.benchmark import benchmark_lib
from official.benchmark import benchmark_definitions
from official.benchmark import config_utils
from official.core import exp_factory
from official.modeling import hyperparams
def _get_benchmark_params(benchmark_models, eval_tflite=False):
"""Formats benchmark params into a list."""
parameterized_benchmark_params = []
for _, benchmarks in benchmark_models.items():
for name, params in benchmarks.items():
if eval_tflite:
execution_modes = ['performance', 'tflite_accuracy']
else:
execution_modes = ['performance', 'accuracy']
for execution_mode in execution_modes:
benchmark_name = '{}.{}'.format(name, execution_mode)
benchmark_params = (
benchmark_name, # First arg is used by ParameterizedBenchmark.
benchmark_name,
params.get('benchmark_function') or benchmark_lib.run_benchmark,
params['experiment_type'],
execution_mode,
params['platform'],
params['precision'],
params['metric_bounds'],
params.get('config_files') or [],
params.get('params_override') or None,
params.get('gin_file') or [])
parameterized_benchmark_params.append(benchmark_params)
return parameterized_benchmark_params
class BaseBenchmark( # pylint: disable=undefined-variable
tf.test.Benchmark, metaclass=benchmark.ParameterizedBenchmark):
"""Common Benchmark.
benchmark.ParameterizedBenchmark is used to auto-create benchmark methods from
the `benchmark` method, according to the benchmarks defined in
benchmark_definitions. The names of the new benchmark methods are
benchmark__{benchmark_name}. _get_benchmark_params is used to generate the
benchmark names and args.
"""
_benchmark_parameters = _get_benchmark_params(
benchmark_definitions.VISION_BENCHMARKS) + _get_benchmark_params(
benchmark_definitions.NLP_BENCHMARKS) + _get_benchmark_params(
benchmark_definitions.QAT_BENCHMARKS,
True) + _get_benchmark_params(
benchmark_definitions.TENSOR_TRACER_BENCHMARKS)
def __init__(self,
output_dir=None,
tpu=None,
tensorflow_models_path: Optional[str] = None):
"""Initialize class.
Args:
output_dir: Base directory to store all output for the test.
tpu: (optional) TPU name to use in a TPU benchmark.
tensorflow_models_path: Full path to tensorflow models directory. Needed
to locate config files.
"""
if os.getenv('BENCHMARK_OUTPUT_DIR'):
self.output_dir = os.getenv('BENCHMARK_OUTPUT_DIR')
elif output_dir:
self.output_dir = output_dir
else:
self.output_dir = '/tmp'
if os.getenv('BENCHMARK_TPU'):
self._resolved_tpu = os.getenv('BENCHMARK_TPU')
elif tpu:
self._resolved_tpu = tpu
else:
self._resolved_tpu = None
if os.getenv('TENSORFLOW_MODELS_PATH'):
self._tensorflow_models_path = os.getenv('TENSORFLOW_MODELS_PATH')
else:
self._tensorflow_models_path = tensorflow_models_path
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
def benchmark(self,
benchmark_name,
benchmark_function,
experiment_type,
execution_mode,
platform,
precision,
metric_bounds,
config_files,
params_override,
gin_file):
with gin.unlock_config():
gin.parse_config_files_and_bindings([
config_utils.get_config_path(
g, base_dir=self._tensorflow_models_path) for g in gin_file
], None)
params = exp_factory.get_exp_config(experiment_type)
for config_file in config_files:
file_path = config_utils.get_config_path(
config_file, base_dir=self._tensorflow_models_path)
params = hyperparams.override_params_dict(
params, file_path, is_strict=True)
if params_override:
params = hyperparams.override_params_dict(
params, params_override, is_strict=True)
# platform in format tpu.[n]x[n] or gpu.[n]
if 'tpu' in platform:
params.runtime.distribution_strategy = 'tpu'
params.runtime.tpu = self._resolved_tpu
elif 'gpu' in platform:
params.runtime.num_gpus = int(platform.split('.')[-1])
params.runtime.distribution_strategy = 'mirrored'
else:
raise NotImplementedError('platform: {} is not supported'.format(platform))
params.runtime.mixed_precision_dtype = precision
params.validate()
params.lock()
tf.io.gfile.makedirs(self._get_model_dir(benchmark_name))
hyperparams.save_params_dict_to_yaml(
params,
os.path.join(self._get_model_dir(benchmark_name), 'params.yaml'))
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s',
pp.pformat(params.as_dict()))
benchmark_data = benchmark_function(
execution_mode, params, self._get_model_dir(benchmark_name))
metrics = []
if execution_mode in ['accuracy', 'tflite_accuracy']:
for metric_bound in metric_bounds:
metric = {
'name': metric_bound['name'],
'value': benchmark_data['metrics'][metric_bound['name']],
'min_value': metric_bound['min_value'],
'max_value': metric_bound['max_value']
}
metrics.append(metric)
metrics.append({'name': 'startup_time',
'value': benchmark_data['startup_time']})
metrics.append({'name': 'exp_per_second',
'value': benchmark_data['examples_per_second']})
self.report_benchmark(
iters=-1,
wall_time=benchmark_data['wall_time'],
metrics=metrics,
extras={'model_name': benchmark_name.split('.')[0],
'platform': platform,
'implementation': 'orbit.ctl',
'parameters': precision})
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model garden benchmark definitions."""
# tf-vision benchmarks
IMAGE_CLASSIFICATION_BENCHMARKS = {
'image_classification.resnet50.tpu.4x4.bf16':
dict(
experiment_type='resnet_imagenet',
platform='tpu.4x4',
precision='bfloat16',
metric_bounds=[{
'name': 'accuracy',
'min_value': 0.76,
'max_value': 0.77
}],
config_files=[('official/vision/configs/experiments/'
'image_classification/imagenet_resnet50_tpu.yaml')]),
'image_classification.resnet50.gpu.8.fp16':
dict(
experiment_type='resnet_imagenet',
platform='gpu.8',
precision='float16',
metric_bounds=[{
'name': 'accuracy',
'min_value': 0.76,
'max_value': 0.77
}],
config_files=[('official/vision/configs/experiments/'
'image_classification/imagenet_resnet50_gpu.yaml')])
}
VISION_BENCHMARKS = {
'image_classification': IMAGE_CLASSIFICATION_BENCHMARKS,
}
NLP_BENCHMARKS = {
}
QAT_BENCHMARKS = {
}
TENSOR_TRACER_BENCHMARKS = {
}
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFM common benchmark training driver."""
import os
import time
from typing import Any, Mapping, Optional
from absl import logging
import orbit
import tensorflow as tf
from official.benchmark import tflite_utils
from official.common import distribute_utils
from official.core import config_definitions
from official.core import task_factory
from official.core import train_utils
from official.modeling import performance
from official.projects.token_dropping import experiment_configs # pylint: disable=unused-import
class _OutputRecorderAction:
"""Simple `Action` that saves the outputs passed to `__call__`."""
def __init__(self):
self.train_output = {}
def __call__(
self,
output: Optional[Mapping[str, tf.Tensor]] = None) -> Mapping[str, Any]:
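# Copy tensor outputs to numpy so they stay available after the training loop.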
self.train_output = {k: v.numpy() for k, v in output.items()
} if output else {}
def run_benchmark(
execution_mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
distribution_strategy: Optional[tf.distribute.Strategy] = None
) -> Mapping[str, Any]:
"""Runs benchmark for a specific experiment.
Args:
execution_mode: A 'str', specifying the mode. Can be 'accuracy',
'performance', or 'tflite_accuracy'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
distribution_strategy: A tf.distribute.Strategy to use. If specified,
it will be used instead of inferring the strategy from params.
Returns:
benchmark_data: returns benchmark data in dict format.
Raises:
NotImplementedError: If an unsupported setup is requested.
"""
# For GPU runs, allow option to set thread mode
if params.runtime.gpu_thread_mode:
os.environ['TF_GPU_THREAD_MODE'] = params.runtime.gpu_thread_mode
logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE'])
# Sets the mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have a significant impact on model speed by utilizing float16 on GPUs
# and bfloat16 on TPUs. loss_scale takes effect only when dtype is float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
strategy = distribution_strategy or distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
trainer = train_utils.create_trainer(
params,
task,
train=True,
evaluate=(execution_mode == 'accuracy'))
# Initialize the model if possible, e.g., from a pre-trained checkpoint.
trainer.initialize()
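# Accuracy modes honor the configured steps_per_loop; performance mode uses
# short 100-step loops so startup time and throughput can be measured.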
steps_per_loop = params.trainer.steps_per_loop if (
execution_mode in ['accuracy', 'tflite_accuracy']) else 100
train_output_recorder = _OutputRecorderAction()
controller = orbit.Controller(
strategy=strategy,
trainer=trainer,
evaluator=trainer if (execution_mode == 'accuracy') else None,
train_actions=[train_output_recorder],
global_step=trainer.global_step,
steps_per_loop=steps_per_loop)
logging.info('Starting execution mode: %s', execution_mode)
with strategy.scope():
# Training for one loop, first loop time includes warmup time.
first_loop_start_time = time.time()
controller.train(steps=steps_per_loop)
first_loop_time = time.time() - first_loop_start_time
# Training for second loop.
second_loop_start_time = time.time()
controller.train(steps=2*steps_per_loop)
second_loop_time = time.time() - second_loop_start_time
if execution_mode == 'accuracy':
controller.train(steps=params.trainer.train_steps)
wall_time = time.time() - first_loop_start_time
eval_logs = trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
benchmark_data = {'metrics': eval_logs}
elif execution_mode == 'performance':
if train_output_recorder.train_output:
benchmark_data = {'metrics': train_output_recorder.train_output}
else:
benchmark_data = {}
elif execution_mode == 'tflite_accuracy':
eval_logs = tflite_utils.train_and_evaluate(
params, task, trainer, controller)
benchmark_data = {'metrics': eval_logs}
else:
raise NotImplementedError(
'The benchmark execution mode is not implemented: %s' %
execution_mode)
# First training loop time contains startup time plus training time, while
# second training loop time is purely training time. Startup time can be
# recovered by subtracting the second training loop time from the first training
# loop time.
startup_time = first_loop_time - second_loop_time
wall_time = time.time() - first_loop_start_time
examples_per_second = steps_per_loop * params.task.train_data.global_batch_size / second_loop_time
benchmark_data.update(
dict(
examples_per_second=examples_per_second,
wall_time=wall_time,
startup_time=startup_time))
return benchmark_data
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.official.benchmark.benchmark_lib."""
# pylint: disable=g-direct-tensorflow-import
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import registry_imports # pylint: disable=unused-import
from official.benchmark import benchmark_lib
from official.core import exp_factory
from official.modeling import hyperparams
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
class BenchmarkLibTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(BenchmarkLibTest, self).setUp()
self._test_config = {
'trainer': {
'steps_per_loop': 10,
'optimizer_config': {
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}
},
'continuous_eval_timeout': 5,
'train_steps': 20,
'validation_steps': 10
},
}
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
execution_mode=['performance', 'accuracy'],
))
def test_benchmark(self, distribution, execution_mode):
model_dir = self.get_temp_dir()
params = exp_factory.get_exp_config('mock')
params = hyperparams.override_params_dict(
params, self._test_config, is_strict=True)
benchmark_data = benchmark_lib.run_benchmark(execution_mode,
params,
model_dir,
distribution)
self.assertIn('examples_per_second', benchmark_data)
self.assertIn('wall_time', benchmark_data)
self.assertIn('startup_time', benchmark_data)
self.assertIn('metrics', benchmark_data)
if __name__ == '__main__':
tf.test.main()