Commit 537aaad5 authored by Chen Chen, committed by A. Unique TensorFlower

Update the other two colabs, customize_encoder.ipynb and nlp_modeling_library_intro.ipynb, to use the latest TensorFlow Models pip package.

PiperOrigin-RevId: 357275077
parent 838339f6
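The substance of the change: both notebooks move from the 2.3.0 pip package to 2.4.0 and pick up the APIs renamed in between. A minimal sketch of the updated entry points (class and argument names are taken from the diff below; the values are illustrative):

```python
# Sketch only; assumes `pip install -q tf-models-official==2.4.0` has run.
from official.nlp import modeling

# 2.3.x spelled these modeling.networks.TransformerEncoder and
# sequence_length; 2.4.x renames both, as the diff below shows.
encoder = modeling.networks.BertEncoder(
    vocab_size=100, num_layers=2, max_sequence_length=16)
classifier = modeling.models.BertClassifier(encoder, num_classes=2)
```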
diff --git a/official/colab/nlp/customize_encoder.ipynb b/official/colab/nlp/customize_encoder.ipynb
--- a/official/colab/nlp/customize_encoder.ipynb
+++ b/official/colab/nlp/customize_encoder.ipynb
 {
+"nbformat": 4,
+"nbformat_minor": 0,
+"metadata": {
+"colab": {
+"name": "Customizing a Transformer Encoder",
+"private_outputs": true,
+"provenance": [],
+"collapsed_sections": [],
+"toc_visible": true
+},
+"kernelspec": {
+"display_name": "Python 3",
+"name": "python3"
+}
+},
 "cells": [
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Bp8t2AI8i7uP"
 },
 "source": [
@@ -12,14 +26,10 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
 "cellView": "form",
-"colab": {},
-"colab_type": "code",
 "id": "rxPj2Lsni9O4"
 },
-"outputs": [],
 "source": [
 "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
 "# you may not use this file except in compliance with the License.\n",
@@ -32,12 +42,13 @@
 "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
 "# See the License for the specific language governing permissions and\n",
 "# limitations under the License."
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "6xS-9i5DrRvO"
 },
 "source": [
@@ -47,30 +58,28 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Mwb9uw1cDXsa"
 },
 "source": [
-"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-"\u003c/table\u003e"
+"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+" </td>\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+" </td>\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+" </td>\n",
+" <td>\n",
+" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+" </td>\n",
+"</table>"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "iLrcV4IyrcGX"
 },
 "source": [
@@ -84,7 +93,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "YYxdyoWgsl8t"
 },
 "source": [
@@ -94,7 +102,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "fEJSFutUsn_h"
 },
 "source": [
@@ -107,21 +114,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "thsKZDjhswhR"
 },
-"outputs": [],
 "source": [
-"!pip install -q tf-models-official==2.3.0"
-]
+"!pip install -q tf-models-official==2.4.0"
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "hpf7JPCVsqtv"
 },
 "source": [
@@ -130,13 +134,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "my4dp-RMssQe"
 },
-"outputs": [],
 "source": [
 "import numpy as np\n",
 "import tensorflow as tf\n",
@@ -144,12 +144,13 @@
 "from official.modeling import activations\n",
 "from official.nlp import modeling\n",
 "from official.nlp.modeling import layers, losses, models, networks"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "vjDmVsFfs85n"
 },
 "source": [
@@ -160,13 +161,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "Oav8sbgstWc-"
 },
-"outputs": [],
 "source": [
 "cfg = {\n",
 " \"vocab_size\": 100,\n",
@@ -177,22 +174,23 @@
 " \"activation\": activations.gelu,\n",
 " \"dropout_rate\": 0.1,\n",
 " \"attention_dropout_rate\": 0.1,\n",
-" \"sequence_length\": 16,\n",
+" \"max_sequence_length\": 16,\n",
 " \"type_vocab_size\": 2,\n",
 " \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
 "}\n",
-"bert_encoder = modeling.networks.TransformerEncoder(**cfg)\n",
+"bert_encoder = modeling.networks.BertEncoder(**cfg)\n",
 "\n",
 "def build_classifier(bert_encoder):\n",
 " return modeling.models.BertClassifier(bert_encoder, num_classes=2)\n",
 "\n",
 "canonical_classifier_model = build_classifier(bert_encoder)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
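The `@@ -177,22 +174,23 @@` hunk elides part of `cfg`, but for migration the cell boils down to two renames: the `sequence_length` config key becomes `max_sequence_length`, and `TransformerEncoder` becomes `BertEncoder`. A hedged sketch with a trimmed, illustrative config:

```python
# Sketch: the cell's two renames applied to a cut-down config. The real
# cell also sets the dropout rates, activation, and the keys elided above.
import tensorflow as tf
from official.nlp import modeling

cfg_sketch = {
    "vocab_size": 100,
    "max_sequence_length": 16,  # was "sequence_length" in 2.3.x
    "type_vocab_size": 2,
    "initializer": tf.keras.initializers.TruncatedNormal(stddev=0.02),
}
# 2.3.x equivalent: modeling.networks.TransformerEncoder(**cfg)
bert_encoder_sketch = modeling.networks.BertEncoder(**cfg_sketch)
```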
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Qe2UWI6_tsHo"
 },
 "source": [
@@ -203,31 +201,28 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "csED2d-Yt5h6"
 },
-"outputs": [],
 "source": [
 "def predict(model):\n",
 " batch_size = 3\n",
 " np.random.seed(0)\n",
 " word_ids = np.random.randint(\n",
-" cfg[\"vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
-" mask = np.random.randint(2, size=(batch_size, cfg[\"sequence_length\"]))\n",
+" cfg[\"vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
+" mask = np.random.randint(2, size=(batch_size, cfg[\"max_sequence_length\"]))\n",
 " type_ids = np.random.randint(\n",
-" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
+" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
 " print(model([word_ids, mask, type_ids], training=False))\n",
 "\n",
 "predict(canonical_classifier_model)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "PzKStEK9t_Pb"
 },
 "source": [
@@ -239,7 +234,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "rmwQfhj6fmKz"
 },
 "source": [
@@ -250,7 +244,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "xsMgEVHAui11"
 },
 "source": [
@@ -263,26 +256,21 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "-JBabpa2AOz8"
 },
 "source": [
 "#### Without Customization\n",
 "\n",
-"Without any customization, `EncoderScaffold` behaves the same the canonical `TransformerEncoder`.\n",
+"Without any customization, `EncoderScaffold` behaves the same the canonical `BertEncoder`.\n",
 "\n",
-"As shown in the following example, `EncoderScaffold` can load `TransformerEncoder`'s weights and output the same values:"
+"As shown in the following example, `EncoderScaffold` can load `BertEncoder`'s weights and output the same values:"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "ktNzKuVByZQf"
 },
-"outputs": [],
 "source": [
 "default_hidden_cfg = dict(\n",
 " num_attention_heads=cfg[\"num_attention_heads\"],\n",
@@ -296,10 +284,9 @@
 " vocab_size=cfg[\"vocab_size\"],\n",
 " type_vocab_size=cfg[\"type_vocab_size\"],\n",
 " hidden_size=cfg[\"hidden_size\"],\n",
-" seq_length=cfg[\"sequence_length\"],\n",
 " initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
 " dropout_rate=cfg[\"dropout_rate\"],\n",
-" max_seq_length=cfg[\"sequence_length\"],\n",
+" max_seq_length=cfg[\"max_sequence_length\"]\n",
 ")\n",
 "default_kwargs = dict(\n",
 " hidden_cfg=default_hidden_cfg,\n",
@@ -309,17 +296,19 @@
 " return_all_layer_outputs=True,\n",
 " pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
 ")\n",
 "\n",
 "encoder_scaffold = modeling.networks.EncoderScaffold(**default_kwargs)\n",
 "classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
 "classifier_model_from_encoder_scaffold.set_weights(\n",
 " canonical_classifier_model.get_weights())\n",
 "predict(classifier_model_from_encoder_scaffold)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "sMaUmLyIuwcs"
 },
 "source": [
@@ -332,18 +321,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "LTinnaG6vcsw"
 },
-"outputs": [],
 "source": [
 "word_ids = tf.keras.layers.Input(\n",
-" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
+" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
 "mask = tf.keras.layers.Input(\n",
-" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
+" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
 "embedding_layer = modeling.layers.OnDeviceEmbedding(\n",
 " vocab_size=cfg['vocab_size'],\n",
 " embedding_width=cfg['hidden_size'],\n",
@@ -353,12 +338,13 @@
 "attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])\n",
 "new_embedding_network = tf.keras.Model([word_ids, mask],\n",
 " [word_embeddings, attention_mask])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "HN7_yu-6O3qI"
 },
 "source": [
@@ -368,21 +354,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "fO9zKFE4OpHp"
 },
-"outputs": [],
 "source": [
 "tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "9cOaGQHLv12W"
 },
 "source": [
@@ -391,13 +374,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "mtFDMNf2vIl9"
 },
-"outputs": [],
 "source": [
 "kwargs = dict(default_kwargs)\n",
 "\n",
@@ -412,12 +391,13 @@
 "\n",
 "# Assert that there are only two inputs.\n",
 "assert len(classifier_model.inputs) == 2"
-]
+],
+"execution_count": null,
+"outputs": []
 },
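The `@@ -412,12 +391,13 @@` hunk hides the middle of this cell, where the custom embedding network gets plugged in. A plausible reconstruction of the pattern, assuming `EncoderScaffold`'s documented `embedding_cls` and `embedding_data` arguments (the elided lines may differ in detail):

```python
# Sketch (assumption): route the two-input embedding model from the cell
# above into the scaffold, sharing its embedding table with the heads.
kwargs = dict(default_kwargs)
kwargs['embedding_cls'] = new_embedding_network
kwargs['embedding_data'] = embedding_layer.embeddings

encoder = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder)

# The custom network takes only word_ids and mask, hence two inputs.
assert len(classifier_model.inputs) == 2
```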
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Z73ZQDtmwg9K"
 },
 "source": [
@@ -432,13 +412,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "uAIarLZgw6pA"
 },
-"outputs": [],
 "source": [
 "kwargs = dict(default_kwargs)\n",
 "\n",
@@ -452,12 +428,13 @@
 "\n",
 "# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
 "assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
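Again the `@@ -452,12 +428,13 @@` hunk elides the lines that perform the swap. A hedged sketch of the pattern, assuming `hidden_cls` accepts a transformer-layer class such as `layers.ReZeroTransformer`:

```python
# Sketch (assumption): replace each transformer block wholesale.
kwargs = dict(default_kwargs)
kwargs['hidden_cls'] = layers.ReZeroTransformer

encoder = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder)

# ReZero adds a learned scalar per block, surfaced as `rezero_alpha`.
assert 'rezero_alpha' in ''.join(
    [w.name for w in classifier_model.trainable_weights])
```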
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "6PMHFdvnxvR0"
 },
 "source": [
@@ -470,7 +447,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "D6FejlgwyAy_"
 },
 "source": [
@@ -485,13 +461,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "nFrSMrZuyNeQ"
 },
-"outputs": [],
 "source": [
 "# Use TalkingHeadsAttention\n",
 "hidden_cfg = dict(default_hidden_cfg)\n",
@@ -508,12 +480,13 @@
 "\n",
 "# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
 "assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
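The elided middle of this cell presumably wires `TalkingHeadsAttention` in through `layers.TransformerScaffold`, which keeps the standard block but lets you override its attention class. A sketch under that assumption:

```python
# Sketch (assumption): override only the attention class inside each block.
hidden_cfg = dict(default_hidden_cfg)
hidden_cfg['attention_cls'] = layers.TalkingHeadsAttention

kwargs = dict(default_kwargs)
kwargs['hidden_cls'] = layers.TransformerScaffold
kwargs['hidden_cfg'] = hidden_cfg

encoder = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder)
assert 'pre_softmax_weight' in ''.join(
    [w.name for w in classifier_model.trainable_weights])
```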
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "kuEJcTyByVvI"
 },
 "source": [
@@ -528,13 +501,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "XAbKy_l4y_-i"
 },
-"outputs": [],
 "source": [
 "# Use TalkingHeadsAttention\n",
 "hidden_cfg = dict(default_hidden_cfg)\n",
@@ -551,12 +520,13 @@
 "\n",
 "# Assert that the variable `gate` from GatedFeedforward exists.\n",
 "assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
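The same scaffold pattern covers the feedforward override, sketched under the same assumptions. (Note the cell's leading comment still reads "Use TalkingHeadsAttention" in the file; the `gate` assertion is what marks it as the GatedFeedforward example.)

```python
# Sketch (assumption): override only the feedforward class inside each block.
hidden_cfg = dict(default_hidden_cfg)
hidden_cfg['feedforward_cls'] = layers.GatedFeedforward

kwargs = dict(default_kwargs)
kwargs['hidden_cls'] = layers.TransformerScaffold
kwargs['hidden_cfg'] = hidden_cfg

encoder = modeling.networks.EncoderScaffold(**kwargs)
classifier_model = build_classifier(encoder)
assert 'gate' in ''.join(
    [w.name for w in classifier_model.trainable_weights])
```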
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "a_8NWUhkzeAq"
 },
 "source": [
@@ -564,29 +534,26 @@
 "\n",
 "Finally, you could also build a new encoder using building blocks in the modeling library.\n",
 "\n",
-"See [AlbertTransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_transformer_encoder.py) as an example:\n"
+"See [AlbertEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_encoder.py) as an example:\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "xsiA3RzUzmUM"
 },
-"outputs": [],
 "source": [
-"albert_encoder = modeling.networks.AlbertTransformerEncoder(**cfg)\n",
+"albert_encoder = modeling.networks.AlbertEncoder(**cfg)\n",
 "classifier_model = build_classifier(albert_encoder)\n",
 "# ... Train the model ...\n",
 "predict(classifier_model)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "MeidDfhlHKSO"
 },
 "source": [
@@ -595,31 +562,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "Uv_juT22HERW"
 },
-"outputs": [],
 "source": [
 "tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
-]
-}
-],
-"metadata": {
-"colab": {
-"collapsed_sections": [],
-"name": "Customizing a Transformer Encoder",
-"private_outputs": true,
-"provenance": [],
-"toc_visible": true
-},
-"kernelspec": {
-"display_name": "Python 3",
-"name": "python3"
-}
-},
-"nbformat": 4,
-"nbformat_minor": 0
-}
+],
+"execution_count": null,
+"outputs": []
+}
+]
+}
\ No newline at end of file
diff --git a/official/colab/nlp/nlp_modeling_library_intro.ipynb b/official/colab/nlp/nlp_modeling_library_intro.ipynb
--- a/official/colab/nlp/nlp_modeling_library_intro.ipynb
+++ b/official/colab/nlp/nlp_modeling_library_intro.ipynb
 {
+"nbformat": 4,
+"nbformat_minor": 0,
+"metadata": {
+"colab": {
+"name": "Introduction to the TensorFlow Models NLP library",
+"private_outputs": true,
+"provenance": [],
+"collapsed_sections": [],
+"toc_visible": true
+},
+"kernelspec": {
+"display_name": "Python 3",
+"name": "python3"
+}
+},
 "cells": [
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "80xnUmoI7fBX"
 },
 "source": [
@@ -12,14 +26,10 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
 "cellView": "form",
-"colab": {},
-"colab_type": "code",
 "id": "8nvTnfs6Q692"
 },
-"outputs": [],
 "source": [
 "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
 "# you may not use this file except in compliance with the License.\n",
@@ -32,12 +42,13 @@
 "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
 "# See the License for the specific language governing permissions and\n",
 "# limitations under the License."
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "WmfcMK5P5C1G"
 },
 "source": [
@@ -47,30 +58,28 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "cH-oJ8R6AHMK"
 },
 "source": [
-"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-"\u003c/table\u003e"
+"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+" </td>\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+" </td>\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+" </td>\n",
+" <td>\n",
+" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+" </td>\n",
+"</table>"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "0H_EFIhq4-MJ"
 },
 "source": [
@@ -82,7 +91,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "2N97-dps_nUk"
 },
 "source": [
@@ -92,7 +100,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "459ygAVl_rg0"
 },
 "source": [
@@ -105,21 +112,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "Y-qGkdh6_sZc"
 },
-"outputs": [],
 "source": [
-"!pip install -q tf-models-official==2.3.0"
-]
+"!pip install -q tf-models-official==2.4.0"
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "e4huSSwyAG_5"
 },
 "source": [
@@ -128,25 +132,22 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "jqYXqtjBAJd9"
 },
-"outputs": [],
 "source": [
 "import numpy as np\n",
 "import tensorflow as tf\n",
 "\n",
 "from official.nlp import modeling\n",
 "from official.nlp.modeling import layers, losses, models, networks"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "djBQWjvy-60Y"
 },
 "source": [
@@ -160,38 +161,34 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "MKuHVlsCHmiq"
 },
 "source": [
-"### Build a `BertPretrainer` model wrapping `TransformerEncoder`\n",
+"### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
 "\n",
-"The [TransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/transformer_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n",
+"The [BertEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/bert_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n",
 "\n",
 "The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "EXkcXz-9BwB3"
 },
-"outputs": [],
 "source": [
 "# Build a small transformer network.\n",
 "vocab_size = 100\n",
 "sequence_length = 16\n",
-"network = modeling.networks.TransformerEncoder(\n",
+"network = modeling.networks.BertEncoder(\n",
 " vocab_size=vocab_size, num_layers=2, sequence_length=16)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "0NH5irV5KTMS"
 },
 "source": [
@@ -202,37 +199,32 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "lZNoZkBrIoff"
 },
-"outputs": [],
 "source": [
 "tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "o7eFOZXiIl-b"
 },
-"outputs": [],
 "source": [
 "# Create a BERT pretrainer with the created network.\n",
 "num_token_predictions = 8\n",
 "bert_pretrainer = modeling.models.BertPretrainer(\n",
 " network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "d5h5HT7gNHx_"
 },
 "source": [
@@ -241,26 +233,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "2tcNfm03IBF7"
 },
-"outputs": [],
 "source": [
 "tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "F2oHrXGUIS0M"
 },
-"outputs": [],
 "source": [
 "# We can feed some dummy data to get masked language model and sentence output.\n",
 "batch_size = 2\n",
@@ -275,12 +261,13 @@
 "sentence_output = outputs[\"classification\"]\n",
 "print(lm_output)\n",
 "print(sentence_output)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
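The `@@ -275,12 +261,13 @@` hunk hides how the dummy inputs are built. A sketch consistent with the visible tail of the cell (`lm_output`, `sentence_output`); the input construction and the output dictionary keys are assumptions:

```python
# Sketch (assumption): four dummy inputs for the pretrainer.
import numpy as np

word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))
mask_data = np.random.randint(2, size=(batch_size, sequence_length))
type_id_data = np.random.randint(2, size=(batch_size, sequence_length))
masked_lm_positions_data = np.random.randint(
    sequence_length, size=(batch_size, num_token_predictions))

outputs = bert_pretrainer(
    [word_id_data, mask_data, type_id_data, masked_lm_positions_data])
lm_output = outputs["masked_lm"]             # assumed key
sentence_output = outputs["classification"]  # key visible in the diff
```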
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "bnx3UCHniCS5"
 },
 "source": [
@@ -290,13 +277,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "k30H4Q86f52x"
 },
-"outputs": [],
 "source": [
 "masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
 "masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
@@ -311,12 +294,13 @@
 " predictions=sentence_output)\n",
 "loss = mlm_loss + sentence_loss\n",
 "print(loss)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
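The elided middle of the loss cell pairs each head with `modeling.losses.weighted_sparse_categorical_crossentropy_loss`; the tail of the second call is visible above. A sketch of the assumed full computation:

```python
# Sketch (assumption): one weighted loss per pretraining head, then summed.
next_sentence_labels_data = np.random.randint(2, size=(batch_size))

mlm_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(
    labels=masked_lm_ids_data,
    predictions=lm_output,
    weights=masked_lm_weights_data)
sentence_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(
    labels=next_sentence_labels_data,
    predictions=sentence_output)
loss = mlm_loss + sentence_loss
```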
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "wrmSs8GjHxVw"
 },
 "source": [
@@ -328,7 +312,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "k8cQVFvBCV4s"
 },
 "source": [
@@ -342,38 +325,34 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "xrLLEWpfknUW"
 },
 "source": [
-"### Build a BertSpanLabeler wrapping TransformerEncoder\n",
+"### Build a BertSpanLabeler wrapping BertEncoder\n",
 "\n",
 "[BertSpanLabeler](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_span_labeler.py) implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
 "\n",
-"Note that `BertSpanLabeler` wraps a `TransformerEncoder`, the weights of which can be restored from the above pretraining model.\n"
+"Note that `BertSpanLabeler` wraps a `BertEncoder`, the weights of which can be restored from the above pretraining model.\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "B941M4iUCejO"
 },
-"outputs": [],
 "source": [
-"network = modeling.networks.TransformerEncoder(\n",
+"network = modeling.networks.BertEncoder(\n",
 " vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
 "\n",
 "# Create a BERT trainer with the created network.\n",
 "bert_span_labeler = modeling.models.BertSpanLabeler(network)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "QpB9pgj4PpMg"
 },
 "source": [
@@ -382,26 +361,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "RbqRNJCLJu4H"
 },
-"outputs": [],
 "source": [
 "tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "fUf1vRxZJwio"
 },
-"outputs": [],
 "source": [
 "# Create a set of 2-dimensional data tensors to feed into the model.\n",
 "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
@@ -412,12 +385,13 @@
 "start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
 "print(start_logits)\n",
 "print(end_logits)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "WqhgQaN1lt-G"
 },
 "source": [
@@ -427,13 +401,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "waqs6azNl3Nn"
 },
-"outputs": [],
 "source": [
 "start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
 "end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
@@ -445,12 +415,13 @@
 "\n",
 "total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
 "print(total_loss)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
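Here the hunk hides the two per-head loss calls; the visible tail averages them into `total_loss`. A sketch of the assumed middle, using stock Keras crossentropy on the logits:

```python
# Sketch (assumption): crossentropy per head, averaged over the batch.
start_loss = tf.keras.losses.sparse_categorical_crossentropy(
    start_positions, start_logits, from_logits=True)
end_loss = tf.keras.losses.sparse_categorical_crossentropy(
    end_positions, end_logits, from_logits=True)

total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
```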
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Zdf03YtZmd_d"
 },
 "source": [
@@ -460,7 +431,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "0A1XnGSTChg9"
 },
 "source": [
@@ -472,38 +442,34 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "MSK8OpZgnQa9"
 },
 "source": [
-"### Build a BertClassifier model wrapping TransformerEncoder\n",
+"### Build a BertClassifier model wrapping BertEncoder\n",
 "\n",
 "[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a [CLS] token classification model containing a single classification head."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "cXXCsffkCphk"
 },
-"outputs": [],
 "source": [
-"network = modeling.networks.TransformerEncoder(\n",
+"network = modeling.networks.BertEncoder(\n",
 " vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
 "\n",
 "# Create a BERT trainer with the created network.\n",
 "num_classes = 2\n",
 "bert_classifier = modeling.models.BertClassifier(\n",
 " network, num_classes=num_classes)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "8tZKueKYP4bB"
 },
 "source": [
@@ -512,26 +478,20 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "snlutm9ZJgEZ"
 },
-"outputs": [],
 "source": [
 "tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "yyHPHsqBJkCz"
 },
-"outputs": [],
 "source": [
 "# Create a set of 2-dimensional data tensors to feed into the model.\n",
 "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
@@ -541,12 +501,13 @@
 "# Feed the data to the model.\n",
 "logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
 "print(logits)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "w--a2mg4nzKm"
 },
 "source": [
@@ -557,45 +518,27 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "9X0S1DoFn_5Q"
 },
-"outputs": [],
 "source": [
 "labels = np.random.randint(num_classes, size=(batch_size))\n",
 "\n",
 "loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
 " labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n",
 "print(loss)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "mzBqOylZo3og"
 },
 "source": [
 "With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example."
 ]
 }
-],
-"metadata": {
-"colab": {
-"collapsed_sections": [],
-"name": "Introduction to the TensorFlow Models NLP library",
-"private_outputs": true,
-"provenance": [],
-"toc_visible": true
-},
-"kernelspec": {
-"display_name": "Python 3",
-"name": "python3"
-}
-},
-"nbformat": 4,
-"nbformat_minor": 0
-}
+]
+}
\ No newline at end of file
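The intro notebook closes by noting that `loss` can be used to optimize the model, deferring to run_classifier.py for the real pipeline. For completeness, one conventional way to wire the classifier into stock Keras, as an illustrative sketch only:

```python
# Sketch: plain Keras fine-tuning loop for the bert_classifier built above.
bert_classifier.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
bert_classifier.fit(
    x=[word_id_data, mask_data, type_id_data],
    y=labels,
    batch_size=batch_size,
    epochs=1)
```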