"docs/vscode:/vscode.git/clone" did not exist on "3065ad595bc842b5c4d3895fdc96db8c1714c962"
Commit 76145d74 authored by Jeremiah Liu's avatar Jeremiah Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 370096487
parent 2eaf8f8a
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "vs3a5tGVAWGI"
},
"source": [
"##### Copyright 2021 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "HYfsarcYBJQp"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aOpqCFEyBQDd"
},
"source": [
"# Uncertainty-aware Deep Language Learning with BERT-SNGP"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6MlSYP6cBT61"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/tutorials/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/docs/models/official/colab/fine_tuning_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" /\u003eSee TF Hub model\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-IM5IzM26GBh"
},
"source": [
"In the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp), you learned how to build SNGP model on top of a deep residual network to improve its ability to quantify its uncertainty. In this tutorial, you will apply SNGP to a natural language understanding (NLU) task by building it on top of a deep BERT encoder to improve deep NLU model's ability in detecting out-of-scope queries. \n",
"\n",
"Specifically, you will:\n",
"* Build BERT-SNGP, a SNGP-augmented [BERT](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2) model.\n",
"* Load the [CLINC Out-of-scope (OOS)](https://www.tensorflow.org/datasets/catalog/clinc_oos) intent detection dataset.\n",
"* Train the BERT-SNGP model.\n",
"* Evaluate the BERT-SNGP model's performance in uncertainty calibration and out-of-domain detection.\n",
"\n",
"Beyond CLINC OOS, the SNGP model has been applied to large-scale datasets such as [Jigsaw toxicity detection](https://www.tensorflow.org/datasets/catalog/wikipedia_toxicity_subtypes), and to the image datasets such as [CIFAR-100](https://www.tensorflow.org/datasets/catalog/cifar100) and [ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012). \n",
"For benchmark results of SNGP and other uncertainty methods, as well as high-quality implementation with end-to-end training / evaluation scripts, you can check out the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines) benchmark."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-bsids4eAYYI"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3sgnLBKk7iuR"
},
"outputs": [],
"source": [
"!pip install tf-models-nightly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M42dnVSk7dVy"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"import sklearn.metrics\n",
"import sklearn.calibration\n",
"\n",
"import tensorflow_hub as hub\n",
"import tensorflow_datasets as tfds\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"import official.nlp.modeling.layers as layers\n",
"import official.nlp.optimization as optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cnRQfguq6GZj"
},
"source": [
"First implement a standard BERT classifier following the [classify text with BERT](https://www.tensorflow.org/tutorials/text/classify_text_with_bert) tutorial. We will use the [BERT-base](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3) encoder, and the built-in [`ClassificationHead`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/cls_head.py) as the classifier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "bNBEGs7s6NHB"
},
"outputs": [],
"source": [
"#@title Standard BERT model\n",
"\n",
"PREPROCESS_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'\n",
"MODEL_HANDLE = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3'\n",
"\n",
"class BertClassifier(tf.keras.Model):\n",
" def __init__(self, \n",
" num_classes=150, inner_dim=768, dropout_rate=0.1,\n",
" **classifier_kwargs):\n",
" \n",
" super().__init__()\n",
" self.classifier_kwargs = classifier_kwargs\n",
"\n",
" # Initiate the BERT encoder components.\n",
" self.bert_preprocessor = hub.KerasLayer(PREPROCESS_HANDLE, name='preprocessing')\n",
" self.bert_hidden_layer = hub.KerasLayer(MODEL_HANDLE, trainable=True, name='bert_encoder')\n",
"\n",
" # Defines the encoder and classification layers.\n",
" self.bert_encoder = self.make_bert_encoder()\n",
" self.classifier = self.make_classification_head(num_classes, inner_dim, dropout_rate)\n",
"\n",
" def make_bert_encoder(self):\n",
" text_inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')\n",
" encoder_inputs = self.bert_preprocessor(text_inputs)\n",
" encoder_outputs = self.bert_hidden_layer(encoder_inputs)\n",
" return tf.keras.Model(text_inputs, encoder_outputs)\n",
"\n",
" def make_classification_head(self, num_classes, inner_dim, dropout_rate):\n",
" return layers.ClassificationHead(\n",
" num_classes=num_classes, \n",
" inner_dim=inner_dim,\n",
" dropout_rate=dropout_rate,\n",
" **self.classifier_kwargs)\n",
"\n",
" def call(self, inputs, **kwargs):\n",
" encoder_outputs = self.bert_encoder(inputs)\n",
" classifier_inputs = encoder_outputs['sequence_output']\n",
" return self.classifier(classifier_inputs, **kwargs)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SbhbNbKk6WNR"
},
"source": [
"### Build SNGP model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p7YakN0V6Oif"
},
"source": [
"To implement a BERT-SNGP model, you only need to replace the `ClassificationHead` with the built-in [`GaussianProcessClassificationHead`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/cls_head.py). Spectral normalization is already pre-packaged into this classification head. Like in the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp), add a covariance reset callback to the model, so the model automatically reset the covariance estimator at the begining of a new epoch to avoid counting the same data twice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QCaJy85y8WeE"
},
"outputs": [],
"source": [
"class ResetCovarianceCallback(tf.keras.callbacks.Callback):\n",
"\n",
" def on_epoch_begin(self, epoch, logs=None):\n",
" \"\"\"Resets covariance matrix at the begining of the epoch.\"\"\"\n",
" if epoch \u003e 0:\n",
" self.model.classifier.reset_covariance_matrix()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YoHgOuiZ6Q4y"
},
"outputs": [],
"source": [
"class SNGPBertClassifier(BertClassifier):\n",
"\n",
" def make_classification_head(self, num_classes, inner_dim, dropout_rate):\n",
" return layers.GaussianProcessClassificationHead(\n",
" num_classes=num_classes, \n",
" inner_dim=inner_dim,\n",
" dropout_rate=dropout_rate,\n",
" gp_cov_momentum=-1,\n",
" temperature=30.,\n",
" **self.classifier_kwargs)\n",
"\n",
" def fit(self, *args, **kwargs):\n",
" \"\"\"Adds ResetCovarianceCallback to model callbacks.\"\"\"\n",
" kwargs['callbacks'] = list(kwargs.get('callbacks', []))\n",
" kwargs['callbacks'].append(ResetCovarianceCallback())\n",
"\n",
" return super().fit(*args, **kwargs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UOj5YWTt6dCe"
},
"source": [
"Note: The `GaussianProcessClassificationHead` takes a new argument `temperature`. It corresponds to the $\\lambda$ parameter in the __mean-field approximation__ introduced in the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp). In practice, this value is usually treated as a hyperparamter, and is finetuned to optimize the model's calibration performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qdU90uDT6hFq"
},
"source": [
"### Load CLINC OOS dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AnuNeyHw6kH7"
},
"source": [
"Now load the [CLINC OOS](https://www.tensorflow.org/datasets/catalog/clinc_oos) intent detection dataset. This dataset contains 15000 user's spoken queries collected over 150 intent classes, it also contains 1000 out-of-domain (OOD) sentences that are not covered by any of the known classes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mkMZN2iA6hhg"
},
"outputs": [],
"source": [
"(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(\n",
" 'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UJSL2nm8Bo02"
},
"source": [
"Make the train and test data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cgkOOZOq6fQL"
},
"outputs": [],
"source": [
"train_examples = clinc_train['text']\n",
"train_labels = clinc_train['intent']\n",
"\n",
"# Makes the in-domain (IND) evaluation data.\n",
"ind_eval_data = (clinc_test['text'], clinc_test['intent'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kw76f6caBq_E"
},
"source": [
"Create a OOD evaluation dataset. For this, combine the in-domain test data `clinc_test` and the out-of-domain data `clinc_test_oos`. We will also assign label 0 to the in-domain examples, and label 1 to the out-of-domain examples. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uVFuzecR64FJ"
},
"outputs": [],
"source": [
"test_data_size = ds_info.splits['test'].num_examples\n",
"oos_data_size = ds_info.splits['test_oos'].num_examples\n",
"\n",
"# Combines the in-domain and out-of-domain test examples.\n",
"oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)\n",
"oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)\n",
"\n",
"# Converts into a TF dataset.\n",
"ood_eval_dataset = tf.data.Dataset.from_tensor_slices(\n",
" {\"text\": oos_texts, \"label\": oos_labels})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZcHwfwfU6qCE"
},
"source": [
"### Train and evaluate"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_VTY6KYc6sBB"
},
"source": [
"First set up the basic training configurations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_-uUkUtk6qWC"
},
"outputs": [],
"source": [
"TRAIN_EPOCHS = 3\n",
"TRAIN_BATCH_SIZE = 32\n",
"EVAL_BATCH_SIZE = 256"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tiEjMdFV6wXQ"
},
"outputs": [],
"source": [
"#@title\n",
"\n",
"def bert_optimizer(learning_rate, \n",
" batch_size=TRAIN_BATCH_SIZE, epochs=TRAIN_EPOCHS, \n",
" warmup_rate=0.1):\n",
" \"\"\"Creates an AdamWeightDecay optimizer with learning rate schedule.\"\"\"\n",
" train_data_size = ds_info.splits['train'].num_examples\n",
" \n",
" steps_per_epoch = int(train_data_size / batch_size)\n",
" num_train_steps = steps_per_epoch * epochs\n",
" num_warmup_steps = int(warmup_rate * num_train_steps) \n",
"\n",
" # Creates learning schedule.\n",
" lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(\n",
" initial_learning_rate=learning_rate,\n",
" decay_steps=num_train_steps,\n",
" end_learning_rate=0.0) \n",
" \n",
" return optimization.AdamWeightDecay(\n",
" learning_rate=lr_schedule,\n",
" weight_decay_rate=0.01,\n",
" epsilon=1e-6,\n",
" exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KX_Hzl3l6w-H"
},
"outputs": [],
"source": [
"optimizer = bert_optimizer(learning_rate=1e-4)\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"metrics = tf.metrics.SparseCategoricalAccuracy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ptn9Cupe6z7o"
},
"outputs": [],
"source": [
"fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,\n",
" epochs=TRAIN_EPOCHS,\n",
" validation_batch_size=EVAL_BATCH_SIZE, \n",
" validation_data=ind_eval_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0ZK5PBwW61jd"
},
"outputs": [],
"source": [
"sngp_model = SNGPBertClassifier()\n",
"sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)\n",
"sngp_model.fit(train_examples, train_labels, **fit_configs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cpDsgTYx63tO"
},
"source": [
"### Evaluate OOD performance"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d5NGVe7L67bB"
},
"source": [
"Evaluate how well the model can detect the unfamiliar out-of-domain queries. For rigorous evaluation, use the OOD evaluation dataset `ood_eval_dataset` built earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "yyLgt_lL7APo"
},
"outputs": [],
"source": [
"#@title\n",
"\n",
"def oos_predict(model, ood_eval_dataset, **model_kwargs):\n",
" oos_labels = []\n",
" oos_probs = []\n",
"\n",
" ood_eval_dataset = ood_eval_dataset.batch(EVAL_BATCH_SIZE)\n",
" for oos_batch in ood_eval_dataset:\n",
" oos_text_batch = oos_batch[\"text\"]\n",
" oos_label_batch = oos_batch[\"label\"] \n",
"\n",
" pred_logits = model(oos_text_batch, **model_kwargs)\n",
" pred_probs_all = tf.nn.softmax(pred_logits, axis=-1)\n",
" pred_probs = tf.reduce_max(pred_probs_all, axis=-1)\n",
"\n",
" oos_labels.append(oos_label_batch)\n",
" oos_probs.append(pred_probs)\n",
"\n",
" oos_probs = tf.concat(oos_probs, axis=0)\n",
" oos_labels = tf.concat(oos_labels, axis=0) \n",
"\n",
" return oos_probs, oos_labels"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dmc2tVXs6_uo"
},
"source": [
"Computes the OOD probabilities as $1 - p(x)$, where $p(x)=softmax(logit(x))$ is the predictive probability."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_9aFVVDO7C7o"
},
"outputs": [],
"source": [
"sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_PC0wwZp7GJD"
},
"outputs": [],
"source": [
"ood_probs = 1 - sngp_probs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AsandMTX7HjX"
},
"source": [
"Now evaluate how well the model's uncertainty score `ood_probs` predicts the out-of-domain label. First compute the Area under precision-recall curve (AUPRC) for OOD probability v.s. OOD detection accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0u5Wx8AP7Mdx"
},
"outputs": [],
"source": [
"precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "axcctOsh7N5A"
},
"outputs": [],
"source": [
"auprc = sklearn.metrics.auc(recall, precision)\n",
"print(f'SNGP AUPRC: {auprc:.4f}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "U_GEqxq-7Q1Y"
},
"source": [
"This matches the SNGP performance reported at the CLINC OOS benchmark under the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8H4vYcyd7Ux2"
},
"source": [
"Next, examine the model's quality in [uncertainty calibration](https://scikit-learn.org/stable/modules/calibration.html), i.e., whether the model's predictive probability corresponds to its predictive accuracy. A well-calibrated model is considered trust-worthy, since, for example, its predictive probability $p(x)=0.8$ means that the model is correct 80% of the time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x5GxrSWJ7SYn"
},
"outputs": [],
"source": [
"prob_true, prob_pred = sklearn.calibration.calibration_curve(\n",
" ood_labels, ood_probs, n_bins=10, strategy='quantile')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ozzJM-D-7XVq"
},
"outputs": [],
"source": [
"plt.plot(prob_pred, prob_true)\n",
"\n",
"plt.plot([0., 1.], [0., 1.], c='k', linestyle=\"--\")\n",
"plt.xlabel('Predictive Probability')\n",
"plt.ylabel('Predictive Accuracy')\n",
"plt.title('Calibration Plots, SNGP')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "36M6HeHx7ZI4"
},
"source": [
"## Resources and further reading"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xdFTpyaP0A-N"
},
"source": [
"* See the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp) for an detailed walkthrough of implementing SNGP from scratch. \n",
"* See [Uncertainty Baselines](https://github.com/google/uncertainty-baselines) for the implementation of SNGP model (and many other uncertainty methods) on a wide variety of benchmark datasets (e.g., [CIFAR](https://www.tensorflow.org/datasets/catalog/cifar100), [ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012), [Jigsaw toxicity detection](https://www.tensorflow.org/datasets/catalog/wikipedia_toxicity_subtypes), etc).\n",
"* For a deeper understanding of the SNGP method, check out the paper [Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness](https://arxiv.org/abs/2006.10108).\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "uncertainty_quantification_with_sngp_bert.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment