{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Bp8t2AI8i7uP" }, "source": [ "##### Copyright 2022 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "rxPj2Lsni9O4" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "6xS-9i5DrRvO" }, "source": [ "# Customizing a Transformer Encoder" ] }, { "cell_type": "markdown", "metadata": { "id": "Mwb9uw1cDXsa" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook" ] },
{ "cell_type": "markdown", "metadata": { "id": "iLrcV4IyrcGX" }, "source": [ "## Learning objectives\n", "\n", "The [TensorFlow Models NLP library](https://github.com/tensorflow/models/tree/master/official/nlp/modeling) is a collection of tools for building and training modern high-performance natural language models.\n", "\n", "`tfm.nlp.networks.EncoderScaffold` is the core of this library, and many new network architectures have been proposed to improve the encoder. In this Colab notebook, we will learn how to customize the encoder to employ new network architectures." ] },
{ "cell_type": "markdown", "metadata": { "id": "YYxdyoWgsl8t" }, "source": [ "## Install and import" ] },
{ "cell_type": "markdown", "metadata": { "id": "fEJSFutUsn_h" }, "source": [ "### Install the TensorFlow Model Garden pip package\n", "\n", "* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes from the `tensorflow_models` GitHub repo. To include the latest changes, you may install `tf-models-nightly`,\n", "which is the nightly Model Garden package created automatically every day.\n", "* `pip` will install all models and dependencies automatically." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "mfHI5JyuJ1y9" }, "outputs": [], "source": [ "!pip install -q opencv-python" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "thsKZDjhswhR" }, "outputs": [], "source": [ "!pip install -q tf-models-official" ] },
{ "cell_type": "markdown", "metadata": { "id": "hpf7JPCVsqtv" }, "source": [ "### Import TensorFlow and other libraries" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "my4dp-RMssQe" }, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "\n", "import tensorflow_models as tfm\n", "nlp = tfm.nlp" ] },
{ "cell_type": "markdown", "metadata": { "id": "vjDmVsFfs85n" }, "source": [ "## Canonical BERT encoder\n", "\n", "Before learning how to customize the encoder, let's first create a canonical BERT encoder and use it to instantiate a `nlp.models.BertClassifier` for a classification task." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Oav8sbgstWc-" }, "outputs": [], "source": [ "cfg = {\n", "    \"vocab_size\": 100,\n", "    \"hidden_size\": 32,\n", "    \"num_layers\": 3,\n", "    \"num_attention_heads\": 4,\n", "    \"intermediate_size\": 64,\n", "    \"activation\": tfm.utils.activations.gelu,\n", "    \"dropout_rate\": 0.1,\n", "    \"attention_dropout_rate\": 0.1,\n", "    \"max_sequence_length\": 16,\n", "    \"type_vocab_size\": 2,\n", "    \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n", "}\n", "bert_encoder = nlp.networks.BertEncoder(**cfg)\n", "\n", "def build_classifier(bert_encoder):\n", "  return nlp.models.BertClassifier(bert_encoder, num_classes=2)\n", "\n", "canonical_classifier_model = build_classifier(bert_encoder)" ] },
{ "cell_type": "markdown", "metadata": { "id": "Qe2UWI6_tsHo" }, "source": [ "`canonical_classifier_model` can be trained using the training data. For details about how to train the model, please see the [Fine tuning BERT](https://www.tensorflow.org/text/tutorials/fine_tune_bert) notebook. We skip the training code here." ] },
{ "cell_type": 
"markdown", "metadata": {}, "source": [ "After training, we can use the model to make predictions." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "csED2d-Yt5h6" }, "outputs": [], "source": [ "def predict(model):\n", "  batch_size = 3\n", "  np.random.seed(0)\n", "  word_ids = np.random.randint(\n", "      cfg[\"vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n", "  mask = np.random.randint(2, size=(batch_size, cfg[\"max_sequence_length\"]))\n", "  type_ids = np.random.randint(\n", "      cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n", "  print(model([word_ids, mask, type_ids], training=False))\n", "\n", "predict(canonical_classifier_model)" ] },
{ "cell_type": "markdown", "metadata": { "id": "PzKStEK9t_Pb" }, "source": [ "## Customize BERT encoder\n", "\n", "A BERT encoder consists of an embedding network and multiple transformer blocks, and each transformer block contains an attention layer and a feedforward layer." ] },
{ "cell_type": "markdown", "metadata": { "id": "rmwQfhj6fmKz" }, "source": [ "We provide easy ways to customize each of those components via (1)\n", "[EncoderScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/encoder_scaffold.py) and (2) [TransformerScaffold](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/transformer_scaffold.py)." ] },
{ "cell_type": "markdown", "metadata": { "id": "xsMgEVHAui11" }, "source": [ "### Use EncoderScaffold\n", "\n", "`networks.EncoderScaffold` allows users to provide a custom embedding subnetwork\n", "(which will replace the standard embedding logic) and/or a custom hidden layer class (which will replace the `Transformer` instantiation in the encoder)." ] },
{ "cell_type": "markdown", "metadata": { "id": "-JBabpa2AOz8" }, "source": [ "#### Without Customization\n", "\n", "Without any customization, `networks.EncoderScaffold` behaves the same as the canonical `networks.BertEncoder`.\n", "\n", "As shown in the following example, `networks.EncoderScaffold` can load `networks.BertEncoder`'s weights and output the same values:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ktNzKuVByZQf" }, "outputs": [], "source": [ "default_hidden_cfg = dict(\n", "    num_attention_heads=cfg[\"num_attention_heads\"],\n", "    intermediate_size=cfg[\"intermediate_size\"],\n", "    intermediate_activation=cfg[\"activation\"],\n", "    dropout_rate=cfg[\"dropout_rate\"],\n", "    attention_dropout_rate=cfg[\"attention_dropout_rate\"],\n", "    kernel_initializer=cfg[\"initializer\"],\n", ")\n", "default_embedding_cfg = dict(\n", "    vocab_size=cfg[\"vocab_size\"],\n", "    type_vocab_size=cfg[\"type_vocab_size\"],\n", "    hidden_size=cfg[\"hidden_size\"],\n", "    initializer=cfg[\"initializer\"],\n", "    dropout_rate=cfg[\"dropout_rate\"],\n", "    max_seq_length=cfg[\"max_sequence_length\"]\n", ")\n", "default_kwargs = dict(\n", "    hidden_cfg=default_hidden_cfg,\n", "    embedding_cfg=default_embedding_cfg,\n", "    num_hidden_instances=cfg[\"num_layers\"],\n", "    pooled_output_dim=cfg[\"hidden_size\"],\n", "    return_all_layer_outputs=True,\n", "    pooler_layer_initializer=cfg[\"initializer\"],\n", ")\n", "\n", "encoder_scaffold = nlp.networks.EncoderScaffold(**default_kwargs)\n", "classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n", "classifier_model_from_encoder_scaffold.set_weights(\n", "    canonical_classifier_model.get_weights())\n", "predict(classifier_model_from_encoder_scaffold)" ] },
{ "cell_type": "markdown", "metadata": { "id": "sMaUmLyIuwcs" }, "source": [ "#### Customize Embedding\n", "\n", "Next, we show how to use a customized embedding network.\n", "\n", "We first build an embedding network to replace the default one. It will have two inputs (`mask` and `word_ids`) instead of three, and won't use positional embeddings." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "LTinnaG6vcsw" }, "outputs": [], "source": [ "word_ids = tf.keras.layers.Input(\n", "    shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n", "mask = tf.keras.layers.Input(\n", "    shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n", "embedding_layer = nlp.layers.OnDeviceEmbedding(\n", "    vocab_size=cfg['vocab_size'],\n", "    embedding_width=cfg['hidden_size'],\n", "    initializer=cfg[\"initializer\"],\n", "    name=\"word_embeddings\")\n", "word_embeddings = embedding_layer(word_ids)\n", "attention_mask = nlp.layers.SelfAttentionMask()([word_embeddings, mask])\n", "new_embedding_network = tf.keras.Model([word_ids, mask],\n", "                                       [word_embeddings, attention_mask])" ] },
{ "cell_type": "markdown", "metadata": { "id": "HN7_yu-6O3qI" }, "source": [ "Inspecting `new_embedding_network`, we can see it takes two inputs:\n", "`input_word_ids` and `input_mask`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "fO9zKFE4OpHp" }, "outputs": [], "source": [ "tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)" ] },
{ "cell_type": "markdown", "metadata": { "id": "9cOaGQHLv12W" }, "source": [ "We can then build a new encoder using the `new_embedding_network` above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "mtFDMNf2vIl9" }, "outputs": [], "source": [ "kwargs = dict(default_kwargs)\n", "\n", "# Use the new embedding network.\n", "kwargs['embedding_cls'] = new_embedding_network\n", "kwargs['embedding_data'] = embedding_layer.embeddings\n", "\n", "encoder_with_customized_embedding = nlp.networks.EncoderScaffold(**kwargs)\n", "classifier_model = build_classifier(encoder_with_customized_embedding)\n", "# ... Train the model ...\n", "print(classifier_model.inputs)\n", "\n", "# Assert that there are only two inputs.\n", "assert len(classifier_model.inputs) == 2" ] },
{ "cell_type": 
"markdown", "metadata": { "id": "Z73ZQDtmwg9K" }, "source": [ "#### Customize Transformer\n", "\n", "You can also override the `hidden_cls` argument in `networks.EncoderScaffold`'s constructor to employ a customized Transformer layer.\n", "\n", "See [the source of `nlp.layers.ReZeroTransformer`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/rezero_transformer.py) for how to implement a customized Transformer layer.\n", "\n", "Following is an example of using `nlp.layers.ReZeroTransformer`:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "uAIarLZgw6pA" }, "outputs": [], "source": [ "kwargs = dict(default_kwargs)\n", "\n", "# Use ReZeroTransformer.\n", "kwargs['hidden_cls'] = nlp.layers.ReZeroTransformer\n", "\n", "encoder_with_rezero_transformer = nlp.networks.EncoderScaffold(**kwargs)\n", "classifier_model = build_classifier(encoder_with_rezero_transformer)\n", "# ... Train the model ...\n", "predict(classifier_model)\n", "\n", "# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n", "assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])" ] },
{ "cell_type": "markdown", "metadata": { "id": "6PMHFdvnxvR0" }, "source": [ "### Use `nlp.layers.TransformerScaffold`\n", "\n", "The above method of customizing the model requires rewriting the whole `nlp.layers.Transformer` layer, but sometimes you may only want to customize the attention layer or the feedforward block. In that case, `nlp.layers.TransformerScaffold` can be used.\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "D6FejlgwyAy_" }, "source": [ "#### Customize Attention Layer\n", "\n", "You can also override the `attention_cls` argument in `layers.TransformerScaffold`'s constructor to employ a customized attention layer.\n", "\n", "See [the source of `nlp.layers.TalkingHeadsAttention`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/talking_heads_attention.py) for how to implement a customized `Attention` layer.\n", "\n", "Following is an example of using `nlp.layers.TalkingHeadsAttention`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "nFrSMrZuyNeQ" }, "outputs": [], "source": [ "# Use TalkingHeadsAttention.\n", "hidden_cfg = dict(default_hidden_cfg)\n", "hidden_cfg['attention_cls'] = nlp.layers.TalkingHeadsAttention\n", "\n", "kwargs = dict(default_kwargs)\n", "kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n", "kwargs['hidden_cfg'] = hidden_cfg\n", "\n", "encoder = nlp.networks.EncoderScaffold(**kwargs)\n", "classifier_model = build_classifier(encoder)\n", "# ... Train the model ...\n", "predict(classifier_model)\n", "\n", "# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n", "assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "tKkZ8spzYmpc" }, "outputs": [], "source": [ "tf.keras.utils.plot_model(encoder_with_rezero_transformer, show_shapes=True, dpi=48)" ] },
{ "cell_type": "markdown", "metadata": { "id": "kuEJcTyByVvI" }, "source": [ "#### Customize Feedforward Layer\n", "\n", "Similarly, one can also customize the feedforward layer.\n", "\n", "See [the source of `nlp.layers.GatedFeedforward`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/gated_feedforward.py) for how to implement a customized feedforward layer.\n", "\n", "Following is an example of using `nlp.layers.GatedFeedforward`:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "XAbKy_l4y_-i" }, "outputs": [], "source": [ "# Use GatedFeedforward.\n", "hidden_cfg = dict(default_hidden_cfg)\n", "hidden_cfg['feedforward_cls'] = nlp.layers.GatedFeedforward\n", "\n", "kwargs = dict(default_kwargs)\n", "kwargs['hidden_cls'] = nlp.layers.TransformerScaffold\n", "kwargs['hidden_cfg'] = hidden_cfg\n", "\n", "encoder_with_gated_feedforward = nlp.networks.EncoderScaffold(**kwargs)\n", "classifier_model = build_classifier(encoder_with_gated_feedforward)\n", "# ... Train the model ...\n", "predict(classifier_model)\n", "\n", "# Assert that the variable `gate` from GatedFeedforward exists.\n", "assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])" ] },
{ "cell_type": 
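"markdown", "metadata": {}, "source": [ "As an optional aside (an addition to this tutorial rather than part of it), the gated feedforward block introduces extra trainable weights in every transformer layer, so the resulting encoder carries more weights than the canonical encoder built earlier. A quick way to see this:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional check (not from the original tutorial): the gated feedforward layer\n", "# adds extra trainable weights per transformer layer.\n", "print('Canonical BERT encoder:   ', len(bert_encoder.trainable_weights), 'weights')\n", "print('Gated feedforward encoder:', len(encoder_with_gated_feedforward.trainable_weights), 'weights')" ] },
{ "cell_type": 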
"markdown", "metadata": { "id": "a_8NWUhkzeAq" }, "source": [ "### Build a new Encoder\n", "\n", "Finally, you could also build a new encoder using the building blocks in the modeling library.\n", "\n", "See [the source for `nlp.networks.AlbertEncoder`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_encoder.py) as an example of how to do this.\n", "\n", "Here is an example using `nlp.networks.AlbertEncoder`:\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "xsiA3RzUzmUM" }, "outputs": [], "source": [ "albert_encoder = nlp.networks.AlbertEncoder(**cfg)\n", "classifier_model = build_classifier(albert_encoder)\n", "# ... Train the model ...\n", "predict(classifier_model)" ] },
{ "cell_type": "markdown", "metadata": { "id": "MeidDfhlHKSO" }, "source": [ "Inspecting the `albert_encoder`, we see that it stacks the same `Transformer` layer multiple times (note the loop-back on the \"Transformer\" block in the plot below)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Uv_juT22HERW" }, "outputs": [], "source": [ "tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)" ] }
], "metadata": { "colab": { "collapsed_sections": [], "name": "customize_encoder.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }