Commit e356598a authored by Chen Chen, committed by A. Unique TensorFlower

Replace modeling.losses.weighted_sparse_categorical_crossentropy_loss with tf.keras.losses.sparse_categorical_crossentropy

PiperOrigin-RevId: 360549819
parent 14b7ac52
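What the substitution amounts to, as a minimal standalone sketch (not part of the commit; the dummy shapes and the hand-written comparison below are assumptions for illustration only):

import numpy as np
import tensorflow as tf

# Hypothetical dummy batch, matching the notebook's classifier example in spirit.
batch_size, num_classes = 2, 2
labels = np.random.randint(num_classes, size=(batch_size,))
logits = tf.random.normal((batch_size, num_classes))

# New call used by the notebook after this change; Keras applies the softmax
# internally because from_logits=True.
loss_new = tf.keras.losses.sparse_categorical_crossentropy(
    labels, logits, from_logits=True)

# The same per-example cross-entropy written out by hand, mirroring the old
# pattern of feeding log-softmaxed predictions (comparison only; this is not
# the removed modeling.losses helper itself).
log_probs = tf.nn.log_softmax(logits, axis=-1)
loss_manual = -tf.gather(log_probs, labels, batch_dims=1)

print(loss_new)
print(loss_manual)  # element-wise, this should match loss_new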
{
-"nbformat": 4,
-"nbformat_minor": 0,
-"metadata": {
-"colab": {
-"name": "Introduction to the TensorFlow Models NLP library",
-"private_outputs": true,
-"provenance": [],
-"collapsed_sections": [],
-"toc_visible": true
-},
-"kernelspec": {
-"display_name": "Python 3",
-"name": "python3"
-}
-},
"cells": [
{
"cell_type": "markdown",
@@ -26,10 +11,12 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"cellView": "form",
"id": "8nvTnfs6Q692"
},
+"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -42,9 +29,7 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -61,20 +46,20 @@
"id": "cH-oJ8R6AHMK"
},
"source": [
-"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
-" <td>\n",
-" <a target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
-" </td>\n",
-" <td>\n",
-" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
-" </td>\n",
-" <td>\n",
-" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
-" </td>\n",
-" <td>\n",
-" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
-" </td>\n",
-"</table>"
+"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
+" \u003ctd\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
+" \u003c/td\u003e\n",
+" \u003ctd\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+" \u003c/td\u003e\n",
+" \u003ctd\u003e\n",
+" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+" \u003c/td\u003e\n",
+" \u003ctd\u003e\n",
+" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
+" \u003c/td\u003e\n",
+"\u003c/table\u003e"
]
},
{
@@ -112,14 +97,14 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "Y-qGkdh6_sZc"
},
+"outputs": [],
"source": [
"!pip install -q tf-models-official==2.4.0"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -132,18 +117,18 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "jqYXqtjBAJd9"
},
+"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -173,18 +158,18 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "EXkcXz-9BwB3"
},
+"outputs": [],
"source": [
"# Build a small transformer network.\n",
"vocab_size = 100\n",
"sequence_length = 16\n",
"network = modeling.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=16)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -199,28 +184,28 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "lZNoZkBrIoff"
},
+"outputs": [],
"source": [
"tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "o7eFOZXiIl-b"
},
+"outputs": [],
"source": [
"# Create a BERT pretrainer with the created network.\n",
"num_token_predictions = 8\n",
"bert_pretrainer = modeling.models.BertPretrainer(\n",
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -233,20 +218,22 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "2tcNfm03IBF7"
},
+"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "F2oHrXGUIS0M"
},
+"outputs": [],
"source": [
"# We can feed some dummy data to get masked language model and sentence output.\n",
"batch_size = 2\n",
@@ -261,9 +248,7 @@
"sentence_output = outputs[\"classification\"]\n",
"print(lm_output)\n",
"print(sentence_output)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -277,9 +262,11 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "k30H4Q86f52x"
},
+"outputs": [],
"source": [
"masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
@@ -294,9 +281,7 @@
" predictions=sentence_output)\n",
"loss = mlm_loss + sentence_loss\n",
"print(loss)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -337,18 +322,18 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "B941M4iUCejO"
},
+"outputs": [],
"source": [
"network = modeling.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n",
"# Create a BERT trainer with the created network.\n",
"bert_span_labeler = modeling.models.BertSpanLabeler(network)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -361,20 +346,22 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "RbqRNJCLJu4H"
},
+"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "fUf1vRxZJwio"
},
+"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
@@ -385,9 +372,7 @@
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
"print(start_logits)\n",
"print(end_logits)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -401,9 +386,11 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "waqs6azNl3Nn"
},
+"outputs": [],
"source": [
"start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
@@ -415,9 +402,7 @@
"\n",
"total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
"print(total_loss)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -452,9 +437,11 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "cXXCsffkCphk"
},
+"outputs": [],
"source": [
"network = modeling.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
@@ -463,9 +450,7 @@
"num_classes = 2\n",
"bert_classifier = modeling.models.BertClassifier(\n",
" network, num_classes=num_classes)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -478,20 +463,22 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "snlutm9ZJgEZ"
},
+"outputs": [],
"source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "yyHPHsqBJkCz"
},
+"outputs": [],
"source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
@@ -501,9 +488,7 @@
"# Feed the data to the model.\n",
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
"print(logits)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -518,18 +503,18 @@
},
{
"cell_type": "code",
+"execution_count": null,
"metadata": {
"id": "9X0S1DoFn_5Q"
},
+"outputs": [],
"source": [
"labels = np.random.randint(num_classes, size=(batch_size))\n",
"\n",
-"loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
-" labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n",
+"loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
+" labels, logits, from_logits=True)\n",
"print(loss)"
-],
-"execution_count": null,
-"outputs": []
+]
},
{
"cell_type": "markdown",
@@ -540,5 +525,20 @@
"With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example."
]
}
-]
-}
\ No newline at end of file
+],
+"metadata": {
+"colab": {
+"collapsed_sections": [],
+"name": "Introduction to the TensorFlow Models NLP library",
+"private_outputs": true,
+"provenance": [],
+"toc_visible": true
+},
+"kernelspec": {
+"display_name": "Python 3",
+"name": "python3"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 0
+}
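Beyond the loss substitution, the diff also reorders notebook keys (cells before metadata, alphabetized keys inside each cell and inside the colab metadata). That ordering is consistent with re-serializing the file through the nbformat library, which writes JSON with sorted keys; the commit does not say which tool was used, so the round trip below is only an assumed reconstruction with a hypothetical path:

import nbformat

# Round-tripping through nbformat normalizes key order in the saved .ipynb.
path = "official/colab/nlp/nlp_modeling_library_intro.ipynb"  # hypothetical local checkout
nb = nbformat.read(path, as_version=4)
nbformat.write(nb, path)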