Commit 828ed8b2 authored by Mark Daoust, committed by A. Unique TensorFlower

Fixup tensorflow_models.nlp tutorials

PiperOrigin-RevId: 443681252
parent 3f635b4d
...@@ -34,14 +34,10 @@ ...@@ -34,14 +34,10 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "fsACVQpVSifi" "id": "2X-XaMSVcLua"
}, },
"source": [ "source": [
"### Install the TensorFlow Model Garden pip package\n", "# Decoding API"
"\n",
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
"which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
] ]
}, },
{ {
...@@ -66,6 +62,30 @@ ...@@ -66,6 +62,30 @@
"\u003c/table\u003e" "\u003c/table\u003e"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"id": "fsACVQpVSifi"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
"which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "G4BhAu01HZcM"
},
"outputs": [],
"source": [
"!pip uninstall -y opencv-python"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -74,7 +94,7 @@ ...@@ -74,7 +94,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"pip install tf-models-nightly" "!pip install tf-models-official"
] ]
}, },
{ {
...@@ -92,9 +112,20 @@ ...@@ -92,9 +112,20 @@
"\n", "\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"\n", "\n",
"from official import nlp\n", "from tensorflow_models import nlp"
"from official.nlp.modeling.ops import sampling_module\n", ]
"from official.nlp.modeling.ops import beam_search" },
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T92ccAzlnGqh"
},
"outputs": [],
"source": [
"def length_norm(length, dtype):\n",
" \"\"\"Return length normalization factor.\"\"\"\n",
" return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)"
] ]
}, },
{ {
...@@ -103,7 +134,8 @@ ...@@ -103,7 +134,8 @@
"id": "0AWgyo-IQ5sP" "id": "0AWgyo-IQ5sP"
}, },
"source": [ "source": [
"# Decoding API\n", "## Overview\n",
"\n",
"This API provides an interface to experiment with different decoding strategies used for auto-regressive models.\n", "This API provides an interface to experiment with different decoding strategies used for auto-regressive models.\n",
"\n", "\n",
"1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class:\n", "1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class:\n",
...@@ -182,7 +214,7 @@ ...@@ -182,7 +214,7 @@
"id": "lV1RRp6ihnGX" "id": "lV1RRp6ihnGX"
}, },
"source": [ "source": [
"# Initialize the Model Hyper-parameters" "## Initialize the Model Hyper-parameters"
] ]
}, },
{ {
...@@ -193,44 +225,32 @@ ...@@ -193,44 +225,32 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"params = {}\n", "params = {\n",
"params['num_heads'] = 2\n", " 'num_heads': 2\n",
"params['num_layers'] = 2\n", " 'num_layers': 2\n",
"params['batch_size'] = 2\n", " 'batch_size': 2\n",
"params['n_dims'] = 256\n", " 'n_dims': 256\n",
"params['max_decode_length'] = 4" " 'max_decode_length': 4}"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "UGvmd0_dRFYI" "id": "CYXkoplAij01"
}, },
"source": [ "source": [
"## What is a Cache?\n", "## Initialize cache. "
"In auto-regressive architectures like Transformer based [Encoder-Decoder](https://arxiv.org/abs/1706.03762) models, \n",
"Cache is used for fast sequential decoding.\n",
"It is a nested dictionary storing pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) for every layer.\n",
"\n",
"```\n",
"{\n",
" 'layer_%d' % layer: {\n",
" 'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']/params['num_heads']], dtype=tf.float32),\n",
" 'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']/params['num_heads']], dtype=tf.float32)\n",
" } for layer in range(params['num_layers']),\n",
" 'model_specific_item' : Model specific tensor shape,\n",
"}\n",
"\n",
"```"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "CYXkoplAij01" "id": "UGvmd0_dRFYI"
}, },
"source": [ "source": [
"# Initialize cache. " "In auto-regressive architectures like Transformer based [Encoder-Decoder](https://arxiv.org/abs/1706.03762) models, \n",
"Cache is used for fast sequential decoding.\n",
"It is a nested dictionary storing pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) for every layer."
] ]
}, },
{ {
...@@ -243,35 +263,15 @@ ...@@ -243,35 +263,15 @@
"source": [ "source": [
"cache = {\n", "cache = {\n",
" 'layer_%d' % layer: {\n", " 'layer_%d' % layer: {\n",
" 'k': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']/params['num_heads']], dtype=tf.float32),\n", " 'k': tf.zeros(\n",
" 'v': tf.zeros([params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims']/params['num_heads']], dtype=tf.float32)\n", " shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
" dtype=tf.float32),\n",
" 'v': tf.zeros(\n",
" shape=[params['batch_size'], params['max_decode_length'], params['num_heads'], params['n_dims'] // params['num_heads']],\n",
" dtype=tf.float32)\n",
" } for layer in range(params['num_layers'])\n", " } for layer in range(params['num_layers'])\n",
" }\n", " }\n",
"print(\"cache key shape for layer 1 :\", cache['layer_1']['k'].shape)" "print(\"cache value shape for layer 1 :\", cache['layer_1']['k'].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nNY3Xn8SiblP"
},
"source": [
"# Define closure for length normalization. **optional.**\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T92ccAzlnGqh"
},
"outputs": [],
"source": [
"def length_norm(length, dtype):\n",
" \"\"\"Return length normalization factor.\"\"\"\n",
" return tf.pow(((5. + tf.cast(length, dtype)) / 6.), 0.0)"
] ]
}, },
{ {
...@@ -280,15 +280,14 @@ ...@@ -280,15 +280,14 @@
"id": "syl7I5nURPgW" "id": "syl7I5nURPgW"
}, },
"source": [ "source": [
"# Create model_fn\n", "### Create model_fn\n",
" In practice, this will be replaced by an actual model implementation such as [here](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer.py#L236)\n", " In practice, this will be replaced by an actual model implementation such as [here](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer.py#L236)\n",
"```\n", "```\n",
"Args:\n", "Args:\n",
"i : Step that is being decoded.\n", "i : Step that is being decoded.\n",
"Returns:\n", "Returns:\n",
" logit probabilities of size [batch_size, 1, vocab_size]\n", " logit probabilities of size [batch_size, 1, vocab_size]\n",
"```\n", "```\n"
"\n"
] ]
}, },
{ {
...@@ -307,15 +306,6 @@ ...@@ -307,15 +306,6 @@
" return probabilities[:, i, :]" " return probabilities[:, i, :]"
] ]
}, },
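For reference, the dummy model that the hunk above truncates can be sketched as follows. This is a minimal illustration only: the `probabilities` tensor, its 3-token vocabulary, and the specific values are assumptions, not the notebook's elided data.

```python
import tensorflow as tf

# Illustrative fixed distributions: [batch_size=2, max_decode_length=4, vocab_size=3].
probabilities = tf.constant([
    [[0.3, 0.4, 0.3], [0.3, 0.3, 0.4],
     [0.1, 0.1, 0.8], [0.1, 0.2, 0.7]],
    [[0.2, 0.5, 0.3], [0.2, 0.7, 0.1],
     [0.1, 0.1, 0.8], [0.1, 0.2, 0.7]]])

def model_fn(i):
  """Returns the distribution for decode step `i`, shaped [batch_size, vocab_size]."""
  return probabilities[:, i, :]
```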
{
"cell_type": "markdown",
"metadata": {
"id": "DBMUkaVmVZBg"
},
"source": [
"# Initialize symbols_to_logits_fn\n"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -339,7 +329,7 @@ ...@@ -339,7 +329,7 @@
"id": "R_tV3jyWVL47" "id": "R_tV3jyWVL47"
}, },
"source": [ "source": [
"# Greedy \n", "## Greedy \n",
"Greedy decoding selects the token id with the highest probability as its next id: $id_t = argmax_{w}P(id | id_{1:t-1})$ at each timestep $t$. The following sketch shows greedy decoding. " "Greedy decoding selects the token id with the highest probability as its next id: $id_t = argmax_{w}P(id | id_{1:t-1})$ at each timestep $t$. The following sketch shows greedy decoding. "
] ]
}, },
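As a usage sketch for greedy decoding (the actual cell lies outside this hunk), the following assumes the `official.nlp.modeling.ops.sampling_module` import path and the `SamplingModule`/`generate` argument names used by the Model Garden decoding tutorial; `symbols_to_logits_fn` is the wrapper defined in a notebook cell elided from this diff, and `eos_id=10` plus the initial ids are placeholders.

```python
from official.nlp.modeling.ops import sampling_module

greedy_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=symbols_to_logits_fn,  # defined earlier in the notebook
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    padded_decode=False)

# generate is assumed to return (decoded ids, scores).
ids, _ = greedy_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print('Greedy decoded ids:', ids)
```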
...@@ -370,7 +360,7 @@ ...@@ -370,7 +360,7 @@
"id": "s4pTTsQXVz5O" "id": "s4pTTsQXVz5O"
}, },
"source": [ "source": [
"# top_k sampling\n", "## top_k sampling\n",
"In *Top-K* sampling, the *K* most likely next token ids are filtered and the probability mass is redistributed among only those *K* ids. " "In *Top-K* sampling, the *K* most likely next token ids are filtered and the probability mass is redistributed among only those *K* ids. "
] ]
}, },
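A comparable sketch for top-k sampling, under the same assumptions as the greedy example above (argument names such as `top_k`, `sample_temperature`, and `enable_greedy` follow the Model Garden decoding tutorial and are not confirmed by this diff):

```python
from official.nlp.modeling.ops import sampling_module

top_k_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=symbols_to_logits_fn,
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_k=tf.constant(3),     # keep only the 3 most likely ids at each step
    padded_decode=False,
    enable_greedy=False)      # sample instead of taking the argmax

ids, _ = top_k_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print('Top-k sampled ids:', ids)
```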
...@@ -404,7 +394,7 @@ ...@@ -404,7 +394,7 @@
"id": "Jp3G-eE_WI4Y" "id": "Jp3G-eE_WI4Y"
}, },
"source": [ "source": [
"# top_p sampling\n", "## top_p sampling\n",
"Instead of sampling only from the most likely *K* token ids, in *Top-p* sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability *p*." "Instead of sampling only from the most likely *K* token ids, in *Top-p* sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability *p*."
] ]
}, },
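The analogous sketch for top-p (nucleus) sampling, with the same caveats as above; `top_p=0.9` is an illustrative value:

```python
from official.nlp.modeling.ops import sampling_module

top_p_obj = sampling_module.SamplingModule(
    length_normalization_fn=length_norm,
    dtype=tf.float32,
    symbols_to_logits_fn=symbols_to_logits_fn,
    vocab_size=3,
    max_decode_length=params['max_decode_length'],
    eos_id=10,
    sample_temperature=tf.constant(1.0),
    top_p=tf.constant(0.9),   # sample from the smallest id set with cumulative prob >= 0.9
    padded_decode=False,
    enable_greedy=False)

ids, _ = top_p_obj.generate(
    initial_ids=tf.constant([9, 1]), initial_cache=cache)
print('Top-p sampled ids:', ids)
```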
...@@ -438,7 +428,7 @@ ...@@ -438,7 +428,7 @@
"id": "2hcuyJ2VWjDz" "id": "2hcuyJ2VWjDz"
}, },
"source": [ "source": [
"# Beam search decoding\n", "## Beam search decoding\n",
"Beam search reduces the risk of missing hidden high probability token ids by keeping the most likely num_beams of hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability. " "Beam search reduces the risk of missing hidden high probability token ids by keeping the most likely num_beams of hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability. "
] ]
}, },
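A rough sketch of the beam search call, again assuming the Model Garden decoding tutorial's `beam_search.sequence_beam_search` argument names (`alpha`, `eos_id`, and so on are assumptions, not confirmed by this diff). Beam search expands the batch dimension by `beam_size`, so the cache here is rebuilt with a batch size of 1:

```python
from official.nlp.modeling.ops import beam_search

beam_size = 2
beam_cache = {
    'layer_%d' % layer: {
        'k': tf.zeros([1, params['max_decode_length'], params['num_heads'],
                       params['n_dims'] // params['num_heads']], dtype=tf.float32),
        'v': tf.zeros([1, params['max_decode_length'], params['num_heads'],
                       params['n_dims'] // params['num_heads']], dtype=tf.float32),
    } for layer in range(params['num_layers'])}

decoded_ids, scores = beam_search.sequence_beam_search(
    symbols_to_logits_fn=symbols_to_logits_fn,
    initial_ids=tf.constant([9], tf.int32),
    initial_cache=beam_cache,
    vocab_size=3,
    beam_size=beam_size,
    alpha=0.6,                 # length-penalty strength
    max_decode_length=params['max_decode_length'],
    eos_id=10)
print('Beam search ids:', decoded_ids)
```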
......
...@@ -95,6 +95,19 @@ ...@@ -95,6 +95,19 @@
"* `pip` will install all models and dependencies automatically." "* `pip` will install all models and dependencies automatically."
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IAOmYthAzI7J"
},
"outputs": [],
"source": [
"# Uninstall colab's opencv-python, it conflicts with `opencv-python-headless`\n",
"# which is installed by tf-models-official\n",
"!pip uninstall -y opencv-python"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -103,7 +116,7 @@ ...@@ -103,7 +116,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"!pip install -q tf-models-official==2.4.0" "!pip install tf-models-official"
] ]
}, },
{ {
...@@ -126,8 +139,7 @@ ...@@ -126,8 +139,7 @@
"import numpy as np\n", "import numpy as np\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"\n", "\n",
"from official.nlp import modeling\n", "from tensorflow_models import nlp"
"from official.nlp.modeling import layers, losses, models, networks"
] ]
}, },
{ {
...@@ -151,9 +163,9 @@ ...@@ -151,9 +163,9 @@
"source": [ "source": [
"### Build a `BertPretrainer` model wrapping `BertEncoder`\n", "### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
"\n", "\n",
"The [BertEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/bert_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n", "The `nlp.networks.BertEncoder` class implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers (`nlp.layers.TransformerEncoderBlock`), but not the masked language model or classification task networks.\n",
"\n", "\n",
"The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives." "The `nlp.models.BertPretrainer` class allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
] ]
}, },
{ {
...@@ -166,9 +178,10 @@ ...@@ -166,9 +178,10 @@
"source": [ "source": [
"# Build a small transformer network.\n", "# Build a small transformer network.\n",
"vocab_size = 100\n", "vocab_size = 100\n",
"sequence_length = 16\n", "network = nlp.networks.BertEncoder(\n",
"network = modeling.networks.BertEncoder(\n", " vocab_size=vocab_size, \n",
" vocab_size=vocab_size, num_layers=2, sequence_length=16)" " # The number of TransformerEncoderBlock layers\n",
" num_layers=3)"
] ]
}, },
{ {
...@@ -177,7 +190,7 @@ ...@@ -177,7 +190,7 @@
"id": "0NH5irV5KTMS" "id": "0NH5irV5KTMS"
}, },
"source": [ "source": [
"Inspecting the encoder, we see it contains few embedding layers, stacked `Transformer` layers and are connected to three input layers:\n", "Inspecting the encoder, we see it contains few embedding layers, stacked `nlp.layers.TransformerEncoderBlock` layers and are connected to three input layers:\n",
"\n", "\n",
"`input_word_ids`, `input_type_ids` and `input_mask`.\n" "`input_word_ids`, `input_type_ids` and `input_mask`.\n"
] ]
...@@ -190,7 +203,7 @@ ...@@ -190,7 +203,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(network, show_shapes=True, expand_nested=True, dpi=48)"
] ]
}, },
{ {
...@@ -203,7 +216,7 @@ ...@@ -203,7 +216,7 @@
"source": [ "source": [
"# Create a BERT pretrainer with the created network.\n", "# Create a BERT pretrainer with the created network.\n",
"num_token_predictions = 8\n", "num_token_predictions = 8\n",
"bert_pretrainer = modeling.models.BertPretrainer(\n", "bert_pretrainer = nlp.models.BertPretrainer(\n",
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')" " network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
] ]
}, },
...@@ -213,7 +226,7 @@ ...@@ -213,7 +226,7 @@
"id": "d5h5HT7gNHx_" "id": "d5h5HT7gNHx_"
}, },
"source": [ "source": [
"Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `Classification` heads." "Inspecting the `bert_pretrainer`, we see it wraps the `encoder` with additional `MaskedLM` and `nlp.layers.ClassificationHead` heads."
] ]
}, },
{ {
...@@ -224,7 +237,7 @@ ...@@ -224,7 +237,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, expand_nested=True, dpi=48)"
] ]
}, },
{ {
...@@ -236,7 +249,9 @@ ...@@ -236,7 +249,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# We can feed some dummy data to get masked language model and sentence output.\n", "# We can feed some dummy data to get masked language model and sentence output.\n",
"sequence_length = 16\n",
"batch_size = 2\n", "batch_size = 2\n",
"\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n", "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
"mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n", "mask_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
"type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n", "type_id_data = np.random.randint(2, size=(batch_size, sequence_length))\n",
...@@ -246,8 +261,8 @@ ...@@ -246,8 +261,8 @@
" [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n", " [word_id_data, mask_data, type_id_data, masked_lm_positions_data])\n",
"lm_output = outputs[\"masked_lm\"]\n", "lm_output = outputs[\"masked_lm\"]\n",
"sentence_output = outputs[\"classification\"]\n", "sentence_output = outputs[\"classification\"]\n",
"print(lm_output)\n", "print(f'lm_output: shape={lm_output.shape}, dtype={lm_output.dtype!r}')\n",
"print(sentence_output)" "print(f'sentence_output: shape={sentence_output.shape}, dtype={sentence_output.dtype!r}')"
] ]
}, },
{ {
...@@ -272,14 +287,15 @@ ...@@ -272,14 +287,15 @@
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n", "masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
"next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n", "next_sentence_labels_data = np.random.randint(2, size=(batch_size))\n",
"\n", "\n",
"mlm_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n", "mlm_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=masked_lm_ids_data,\n", " labels=masked_lm_ids_data,\n",
" predictions=lm_output,\n", " predictions=lm_output,\n",
" weights=masked_lm_weights_data)\n", " weights=masked_lm_weights_data)\n",
"sentence_loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n", "sentence_loss = nlp.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=next_sentence_labels_data,\n", " labels=next_sentence_labels_data,\n",
" predictions=sentence_output)\n", " predictions=sentence_output)\n",
"loss = mlm_loss + sentence_loss\n", "loss = mlm_loss + sentence_loss\n",
"\n",
"print(loss)" "print(loss)"
] ]
}, },
...@@ -290,8 +306,7 @@ ...@@ -290,8 +306,7 @@
}, },
"source": [ "source": [
"With the loss, you can optimize the model.\n", "With the loss, you can optimize the model.\n",
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_pretraining.py) for the full example.\n", "After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_pretraining.py) for the full example.\n"
"\n"
] ]
}, },
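One minimal way to carry out the "save the encoder weights for fine-tuning" step mentioned above is with a plain `tf.train.Checkpoint`; the checkpoint path is a placeholder, and the linked `run_pretraining.py` handles this more completely:

```python
# Checkpoint only the encoder so a downstream task model can restore it later.
checkpoint = tf.train.Checkpoint(encoder=network)
saved_path = checkpoint.save('/tmp/bert_pretraining/encoder_ckpt')  # placeholder path

# Later, when building a fine-tuning model:
new_encoder = nlp.networks.BertEncoder(vocab_size=vocab_size, num_layers=2)
tf.train.Checkpoint(encoder=new_encoder).restore(saved_path)
```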
{ {
...@@ -315,9 +330,9 @@ ...@@ -315,9 +330,9 @@
"source": [ "source": [
"### Build a BertSpanLabeler wrapping BertEncoder\n", "### Build a BertSpanLabeler wrapping BertEncoder\n",
"\n", "\n",
"[BertSpanLabeler](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_span_labeler.py) implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n", "The `nlp.models.BertSpanLabeler` class implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
"\n", "\n",
"Note that `BertSpanLabeler` wraps a `BertEncoder`, the weights of which can be restored from the above pretraining model.\n" "Note that `nlp.models.BertSpanLabeler` wraps a `nlp.networks.BertEncoder`, the weights of which can be restored from the above pretraining model.\n"
] ]
}, },
{ {
...@@ -328,11 +343,11 @@ ...@@ -328,11 +343,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network = modeling.networks.BertEncoder(\n", "network = nlp.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n", " vocab_size=vocab_size, num_layers=2)\n",
"\n", "\n",
"# Create a BERT trainer with the created network.\n", "# Create a BERT trainer with the created network.\n",
"bert_span_labeler = modeling.models.BertSpanLabeler(network)" "bert_span_labeler = nlp.models.BertSpanLabeler(network)"
] ]
}, },
{ {
...@@ -352,7 +367,7 @@ ...@@ -352,7 +367,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, expand_nested=True, dpi=48)"
] ]
}, },
{ {
...@@ -370,8 +385,9 @@ ...@@ -370,8 +385,9 @@
"\n", "\n",
"# Feed the data to the model.\n", "# Feed the data to the model.\n",
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n", "start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
"print(start_logits)\n", "\n",
"print(end_logits)" "print(f'start_logits: shape={start_logits.shape}, dtype={start_logits.dtype!r}')\n",
"print(f'end_logits: shape={end_logits.shape}, dtype={end_logits.dtype!r}')"
] ]
}, },
{ {
...@@ -432,7 +448,7 @@ ...@@ -432,7 +448,7 @@
"source": [ "source": [
"### Build a BertClassifier model wrapping BertEncoder\n", "### Build a BertClassifier model wrapping BertEncoder\n",
"\n", "\n",
"[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a [CLS] token classification model containing a single classification head." "`nlp.models.BertClassifier` implements a [CLS] token classification model containing a single classification head."
] ]
}, },
{ {
...@@ -443,12 +459,12 @@ ...@@ -443,12 +459,12 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network = modeling.networks.BertEncoder(\n", "network = nlp.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n", " vocab_size=vocab_size, num_layers=2)\n",
"\n", "\n",
"# Create a BERT trainer with the created network.\n", "# Create a BERT trainer with the created network.\n",
"num_classes = 2\n", "num_classes = 2\n",
"bert_classifier = modeling.models.BertClassifier(\n", "bert_classifier = nlp.models.BertClassifier(\n",
" network, num_classes=num_classes)" " network, num_classes=num_classes)"
] ]
}, },
...@@ -469,7 +485,7 @@ ...@@ -469,7 +485,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(bert_classifier, show_shapes=True, expand_nested=True, dpi=48)"
] ]
}, },
{ {
...@@ -487,7 +503,7 @@ ...@@ -487,7 +503,7 @@
"\n", "\n",
"# Feed the data to the model.\n", "# Feed the data to the model.\n",
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n", "logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
"print(logits)" "print(f'logits: shape={logits.shape}, dtype={logits.dtype!r}')"
] ]
}, },
{ {
...@@ -529,8 +545,7 @@ ...@@ -529,8 +545,7 @@
"metadata": { "metadata": {
"colab": { "colab": {
"collapsed_sections": [], "collapsed_sections": [],
"name": "Introduction to the TensorFlow Models NLP library", "name": "nlp_modeling_library_intro.ipynb",
"private_outputs": true,
"provenance": [], "provenance": [],
"toc_visible": true "toc_visible": true
}, },
......