" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/tutorials/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/uncertainty_quantification_with_sngp_bert.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
"In the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp), you learned how to build SNGP model on top of a deep residual network to improve its ability to quantify its uncertainty. In this tutorial, you will apply SNGP to a natural language understanding (NLU) task by building it on top of a deep BERT encoder to improve deep NLU model's ability in detecting out-of-scope queries. \n",
"\n",
"Specifically, you will:\n",
"* Build BERT-SNGP, a SNGP-augmented [BERT](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2) model.\n",
"* Load the [CLINC Out-of-scope (OOS)](https://www.tensorflow.org/datasets/catalog/clinc_oos) intent detection dataset.\n",
"* Train the BERT-SNGP model.\n",
"* Evaluate the BERT-SNGP model's performance in uncertainty calibration and out-of-domain detection.\n",
"\n",
"Beyond CLINC OOS, the SNGP model has been applied to large-scale datasets such as [Jigsaw toxicity detection](https://www.tensorflow.org/datasets/catalog/wikipedia_toxicity_subtypes), and to the image datasets such as [CIFAR-100](https://www.tensorflow.org/datasets/catalog/cifar100) and [ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012). \n",
"For benchmark results of SNGP and other uncertainty methods, as well as high-quality implementation with end-to-end training / evaluation scripts, you can check out the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines) benchmark."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-bsids4eAYYI"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3sgnLBKk7iuR"
},
"outputs": [],
"source": [
"!pip install tf-models-nightly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M42dnVSk7dVy"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"import sklearn.metrics\n",
"import sklearn.calibration\n",
"\n",
"import tensorflow_hub as hub\n",
"import tensorflow_datasets as tfds\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"import official.nlp.modeling.layers as layers\n",
"import official.nlp.optimization as optimization"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cnRQfguq6GZj"
},
"source": [
"First implement a standard BERT classifier following the [classify text with BERT](https://www.tensorflow.org/tutorials/text/classify_text_with_bert) tutorial. We will use the [BERT-base](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3) encoder, and the built-in [`ClassificationHead`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/cls_head.py) as the classifier."
"To implement a BERT-SNGP model, you only need to replace the `ClassificationHead` with the built-in [`GaussianProcessClassificationHead`](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/cls_head.py). Spectral normalization is already pre-packaged into this classification head. Like in the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp), add a covariance reset callback to the model, so the model automatically reset the covariance estimator at the begining of a new epoch to avoid counting the same data twice."
"Note: The `GaussianProcessClassificationHead` takes a new argument `temperature`. It corresponds to the $\\lambda$ parameter in the __mean-field approximation__ introduced in the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp). In practice, this value is usually treated as a hyperparamter, and is finetuned to optimize the model's calibration performance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qdU90uDT6hFq"
},
"source": [
"### Load CLINC OOS dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AnuNeyHw6kH7"
},
"source": [
"Now load the [CLINC OOS](https://www.tensorflow.org/datasets/catalog/clinc_oos) intent detection dataset. This dataset contains 15000 user's spoken queries collected over 150 intent classes, it also contains 1000 out-of-domain (OOD) sentences that are not covered by any of the known classes."
"Create a OOD evaluation dataset. For this, combine the in-domain test data `clinc_test` and the out-of-domain data `clinc_test_oos`. We will also assign label 0 to the in-domain examples, and label 1 to the out-of-domain examples. "
"Evaluate how well the model can detect the unfamiliar out-of-domain queries. For rigorous evaluation, use the OOD evaluation dataset `ood_eval_dataset` built earlier."
"Now evaluate how well the model's uncertainty score `ood_probs` predicts the out-of-domain label. First compute the Area under precision-recall curve (AUPRC) for OOD probability v.s. OOD detection accuracy."
"This matches the SNGP performance reported at the CLINC OOS benchmark under the [Uncertainty Baselines](https://github.com/google/uncertainty-baselines)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8H4vYcyd7Ux2"
},
"source": [
"Next, examine the model's quality in [uncertainty calibration](https://scikit-learn.org/stable/modules/calibration.html), i.e., whether the model's predictive probability corresponds to its predictive accuracy. A well-calibrated model is considered trust-worthy, since, for example, its predictive probability $p(x)=0.8$ means that the model is correct 80% of the time."
"* See the [SNGP tutorial](https://www.tensorflow.org/tutorials/uncertainty/sngp) for an detailed walkthrough of implementing SNGP from scratch. \n",
"* See [Uncertainty Baselines](https://github.com/google/uncertainty-baselines) for the implementation of SNGP model (and many other uncertainty methods) on a wide variety of benchmark datasets (e.g., [CIFAR](https://www.tensorflow.org/datasets/catalog/cifar100), [ImageNet](https://www.tensorflow.org/datasets/catalog/imagenet2012), [Jigsaw toxicity detection](https://www.tensorflow.org/datasets/catalog/wikipedia_toxicity_subtypes), etc).\n",
"* For a deeper understanding of the SNGP method, check out the paper [Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness](https://arxiv.org/abs/2006.10108).\n"