"### Install the TensorFlow Model Garden pip package\n",
"# Decoding API"
"\n",
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
"which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
]
},
{
...
@@ -66,6 +62,30 @@
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsACVQpVSifi"
},
"source": [
"### Install the TensorFlow Model Garden pip package\n",
"\n",
"* `tf-models-official` is the stable Model Garden package. Note that it may not include the latest changes in the `tensorflow_models` github repo. To include latest changes, you may install `tf-models-nightly`,\n",
"which is the nightly Model Garden package created daily automatically.\n",
"* pip will install all models and dependencies automatically."
"This API provides an interface to experiment with different decoding strategies used for auto-regressive models.\n",
"This API provides an interface to experiment with different decoding strategies used for auto-regressive models.\n",
"\n",
"\n",
"1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class:\n",
"1. The following sampling strategies are provided in sampling_module.py, which inherits from the base Decoding class:\n",
...
@@ -182,7 +214,7 @@
"id": "lV1RRp6ihnGX"
},
"source": [
"## Initialize the Model Hyper-parameters"
]
},
{
...
@@ -193,44 +225,32 @@
},
"outputs": [],
"source": [
"params = {\n",
"    'num_heads': 2,\n",
"    'num_layers': 2,\n",
"    'batch_size': 2,\n",
"    'n_dims': 256,\n",
"    'max_decode_length': 4\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CYXkoplAij01"
},
"source": [
"## What is a Cache?\n",
"\n",
"In auto-regressive architectures like Transformer based [Encoder-Decoder](https://arxiv.org/abs/1706.03762) models, \n",
"Cache is used for fast sequential decoding.\n",
"It is a nested dictionary storing pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) for every layer.\n",
"\n",
"```\n",
"{\n",
"  ...\n",
"  'model_specific_item' : Model specific tensor shape,\n",
"}\n",
"```",
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UGvmd0_dRFYI"
},
"source": [
"## Initialize cache.\n",
"\n",
"The cache described above is initialized before decoding starts. The decoding strategies below also need a model function that returns next-token probabilities for each decode step. In practice, this will be replaced by an actual model implementation such as [here](https://github.com/tensorflow/models/blob/master/official/nlp/transformer/transformer.py#L236)\n",
"\n",
"```\n",
"Args:\n",
"i : Step that is being decoded.\n",
"Returns:\n",
" logit probabilities of size [batch_size, 1, vocab_size]\n",
"```\n",
]
},
{
...
@@ -307,15 +306,6 @@
" return probabilities[:, i, :]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DBMUkaVmVZBg"
},
"source": [
"# Initialize symbols_to_logits_fn\n"
]
},
{
"cell_type": "code",
"execution_count": null,
...
@@ -339,7 +329,7 @@
"id": "R_tV3jyWVL47"
},
"source": [
"## Greedy\n",
"Greedy decoding selects the token id with the highest probability as its next id: $id_t = argmax_{id}P(id | id_{1:t-1})$ at each timestep $t$. The following sketch shows greedy decoding.",
]
},
...
@@ -370,7 +360,7 @@
"id": "s4pTTsQXVz5O"
},
"source": [
"## top_k sampling\n",
"In *Top-K* sampling, the *K* most likely next token ids are filtered and the probability mass is redistributed among only those *K* ids.",
]
},
...
@@ -404,7 +394,7 @@
"id": "Jp3G-eE_WI4Y"
},
"source": [
"## top_p sampling\n",
"Instead of sampling only from the most likely *K* token ids, *Top-p* sampling chooses from the smallest possible set of ids whose cumulative probability exceeds the probability *p*.",
]
},
...
@@ -438,7 +428,7 @@
"id": "2hcuyJ2VWjDz"
},
"source": [
"## Beam search decoding\n",
"Beam search reduces the risk of missing hidden high probability token ids by keeping the `num_beams` most likely hypotheses at each time step and eventually choosing the hypothesis that has the overall highest probability.",
"### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
"### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
"\n",
"\n",
"The [BertEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/bert_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n",
"The `nlp.networks.BertEncoder` class implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers (`nlp.layers.TransformerEncoderBlock`), but not the masked language model or classification task networks.\n",
"\n",
"\n",
"The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
"The `nlp.models.BertPretrainer` class allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
" # The number of TransformerEncoderBlock layers\n",
" num_layers=3)"
]
},
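{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of wrapping an encoder in a pretrainer. The configuration values and constructor arguments below are illustrative assumptions (it also assumes the `nlp` modeling package is imported as elsewhere in this notebook); check the `nlp.networks.BertEncoder` and `nlp.models.BertPretrainer` documentation for the full set:\n",
"\n",
"```python\n",
"# Hypothetical, small configuration purely for illustration.\n",
"bert_encoder = nlp.networks.BertEncoder(vocab_size=100, num_layers=3)\n",
"bert_pretrainer = nlp.models.BertPretrainer(\n",
"    network=bert_encoder,       # the transformer stack to wrap\n",
"    num_classes=2,              # next-sentence prediction head\n",
"    num_token_predictions=8)    # masked-LM predictions per sequence\n",
"```"
]
},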
{
...
@@ -177,7 +190,7 @@
"id": "0NH5irV5KTMS"
},
"source": [
"Inspecting the encoder, we see that it contains a few embedding layers and stacked `nlp.layers.TransformerEncoderBlock` layers, which are connected to three input layers:\n",
"\n",
"`input_word_ids`, `input_type_ids` and `input_mask`.\n",
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_pretraining.py) for the full example.\n",
"After training, we can save the weights of TransformerEncoder for the downstream fine-tuning tasks. Please see [run_pretraining.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_pretraining.py) for the full example.\n"
"\n"
]
},
{
...
@@ -315,9 +330,9 @@
"source": [
"### Build a BertSpanLabeler wrapping BertEncoder\n",
"\n",
"The `nlp.models.BertSpanLabeler` class implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
"\n",
"Note that `nlp.models.BertSpanLabeler` wraps a `nlp.networks.BertEncoder`, the weights of which can be restored from the above pretraining model.\n",
"### Build a BertClassifier model wrapping BertEncoder\n",
"### Build a BertClassifier model wrapping BertEncoder\n",
"\n",
"\n",
"[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a [CLS] token classification model containing a single classification head."
"`nlp.models.BertClassifier` implements a [CLS] token classification model containing a single classification head."