"tests/vscode:/vscode.git/clone" did not exist on "4f6399bedd28a9af0e18b1ebf84c38eb07ffa769"
Commit 2e9bb539 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 7bae5317 8fba84f8
@@ -19,7 +19,7 @@ In the near future, we will add:
 * State-of-the-art language understanding models.
 * State-of-the-art image classification models.
-* State-of-the-art objection detection and instance segmentation models.
+* State-of-the-art object detection and instance segmentation models.
 ## Table of Contents
...
@@ -3,7 +3,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "vXLA5InzXydn"
 },
 "source": [
@@ -15,8 +14,6 @@
 "execution_count": null,
 "metadata": {
 "cellView": "form",
-"colab": {},
-"colab_type": "code",
 "id": "RuRlpLL-X0R_"
 },
 "outputs": [],
@@ -37,7 +34,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "1mLJmVotXs64"
 },
 "source": [
@@ -47,7 +43,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "hYEwGTeCXnnX"
 },
 "source": [
@@ -73,7 +68,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "YN2ACivEPxgD"
 },
 "source": [
@@ -85,7 +79,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "s2d9S2CSSO1z"
 },
 "source": [
@@ -95,7 +88,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "fsACVQpVSifi"
 },
 "source": [
@@ -110,19 +102,16 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "NvNr2svBM-p3"
 },
 "outputs": [],
 "source": [
-"!pip install -q tf-models-official==2.3.0"
+"!pip install -q tf-models-official==2.4.0"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "U-7qPCjWUAyy"
 },
 "source": [
@@ -133,8 +122,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "lXsXev5MNr20"
 },
 "outputs": [],
@@ -163,13 +150,12 @@
 "import official.nlp.data.classifier_data_lib\n",
 "import official.nlp.modeling.losses\n",
 "import official.nlp.modeling.models\n",
-"import official.nlp.modeling.networks"
+"import official.nlp.modeling.networks\n"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "mbanlzTvJBsz"
 },
 "source": [
@@ -179,7 +165,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "PpW0x8TpR8DT"
 },
 "source": [
@@ -190,20 +175,17 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "vzRHOLciR8eq"
 },
 "outputs": [],
 "source": [
-"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12\"\n",
+"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12\"\n",
 "tf.io.gfile.listdir(gs_folder_bert)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "9uFskufsR2LT"
 },
 "source": [
@@ -214,19 +196,16 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "e0dAkUttJAzj"
 },
 "outputs": [],
 "source": [
-"hub_url_bert = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\""
+"hub_url_bert = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3\""
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Qv6abtRvH4xO"
 },
 "source": [
@@ -239,7 +218,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "28DvUhC1YUiB"
 },
 "source": [
@@ -257,8 +235,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "Ijikx5OsH9AT"
 },
 "outputs": [],
@@ -272,8 +248,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "xf9zz4vLYXjr"
 },
 "outputs": [],
@@ -284,7 +258,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "ZgBg2r2nYT-K"
 },
 "source": [
@@ -295,8 +268,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "IQrHxv7W7jH5"
 },
 "outputs": [],
@@ -307,7 +278,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "vhsVWYNxazz5"
 },
 "source": [
@@ -318,8 +288,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "n0gfc_VTayfQ"
 },
 "outputs": [],
@@ -330,7 +298,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "38zJcap6xkbC"
 },
 "source": [
@@ -341,8 +308,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "xON_i6SkwApW"
 },
 "outputs": [],
@@ -356,7 +321,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "9fbTyfJpNr7x"
 },
 "source": [
@@ -366,7 +330,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "wqeN54S61ZKQ"
 },
 "source": [
@@ -381,8 +344,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "idxyhmrCQcw5"
 },
 "outputs": [],
@@ -398,7 +359,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "zYHDSquU2lDU"
 },
 "source": [
@@ -409,8 +369,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "L_OfOYPg853R"
 },
 "outputs": [],
@@ -424,7 +382,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "kkAXLtuyWWDI"
 },
 "source": [
@@ -438,7 +395,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "62UTWLQd9-LB"
 },
 "source": [
@@ -451,8 +407,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "bdL-dRNRBRJT"
 },
 "outputs": [],
@@ -463,7 +417,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "UrPktnqpwqie"
 },
 "source": [
@@ -474,8 +427,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "BR7BmtU498Bh"
 },
 "outputs": [],
@@ -495,8 +446,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "has42aUdfky-"
 },
 "outputs": [],
@@ -508,7 +457,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "MU9lTWy_xXbb"
 },
 "source": [
@@ -519,8 +467,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "USD8uihw-g4J"
 },
 "outputs": [],
@@ -533,7 +479,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "xmNv4l4k-dBZ"
 },
 "source": [
@@ -543,7 +488,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "DIWjNIKq-ldh"
 },
 "source": [
@@ -556,7 +500,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "ulNZ4U96-8JZ"
 },
 "source": [
@@ -567,8 +510,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "EezOO9qj91kP"
 },
 "outputs": [],
@@ -581,7 +522,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "rxLenwAvCkBf"
 },
 "source": [
@@ -592,8 +532,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "2CetH_5C9P2m"
 },
 "outputs": [],
@@ -609,7 +547,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "P5UBnCn8Ii6s"
 },
 "source": [
@@ -622,8 +559,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "sDGiWYPLEd5a"
 },
 "outputs": [],
@@ -666,8 +601,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "yuLKxf6zHxw-"
 },
 "outputs": [],
@@ -685,7 +618,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "7FC5aLVxKVKK"
 },
 "source": [
@@ -696,8 +628,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "jyjTdGpFhO_1"
 },
 "outputs": [],
@@ -711,7 +641,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "FSwymsbkbLDA"
 },
 "source": [
@@ -721,7 +650,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Efrj3Cn1kLAp"
 },
 "source": [
@@ -731,7 +659,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "xxpOY5r2Ayq6"
 },
 "source": [
@@ -742,8 +669,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "ujapVfZ_AKW7"
 },
 "outputs": [],
@@ -761,7 +686,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "96ldxDSwkVkj"
 },
 "source": [
@@ -774,8 +698,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "cH682__U0FBv"
 },
 "outputs": [],
@@ -787,7 +709,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "XqKp3-5GIZlw"
 },
 "source": [
@@ -798,8 +719,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "bAQblMIjwkvx"
 },
 "outputs": [],
@@ -810,7 +729,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "sFmVG4SKZAw8"
 },
 "source": [
@@ -821,8 +739,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "VTjgPbp4ZDKo"
 },
 "outputs": [],
@@ -837,7 +753,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Q0NTdwZsQK8n"
 },
 "source": [
@@ -850,8 +765,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "8L__-erBwLIQ"
 },
 "outputs": [],
@@ -862,7 +775,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "mKAvkQc3heSy"
 },
 "source": [
@@ -875,21 +787,18 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "97Ll2Gichd_Y"
 },
 "outputs": [],
 "source": [
-"checkpoint = tf.train.Checkpoint(model=bert_encoder)\n",
-"checkpoint.restore(\n",
+"checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
+"checkpoint.read(\n",
 " os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "2oHOql35k3Dd"
 },
 "source": [
@@ -899,7 +808,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "115caFLMk-_l"
 },
 "source": [
@@ -913,8 +821,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "w8qXKRZuCwW4"
 },
 "outputs": [],
@@ -937,7 +843,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "pXRGxiRNEHS2"
 },
 "source": [
@@ -948,8 +853,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "eQNA16bhDpky"
 },
 "outputs": [],
@@ -960,7 +863,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "xqu_K71fJQB8"
 },
 "source": [
@@ -970,7 +872,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "78FEUOOEkoP0"
 },
 "source": [
@@ -980,7 +881,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "OTNcA0O0nSq9"
 },
 "source": [
@@ -991,8 +891,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "nzi8hjeTQTRs"
 },
 "outputs": [],
@@ -1015,7 +913,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "IFtKFWbNKb0u"
 },
 "source": [
@@ -1028,8 +925,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "9ZoUgDUNJPz3"
 },
 "outputs": [],
@@ -1049,7 +944,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "7ynJibkBRTJF"
 },
 "source": [
@@ -1060,8 +954,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "umo0ttrgRYIM"
 },
 "outputs": [],
@@ -1076,8 +968,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "utGl0M3aZCE4"
 },
 "outputs": [],
@@ -1088,7 +978,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "fVo_AnT0l26j"
 },
 "source": [
@@ -1101,8 +990,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "Nl5x6nElZqkP"
 },
 "outputs": [],
@@ -1115,8 +1002,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "collapsed": true,
"colab_type": "code",
"id": "y_ACvKPsVUXC" "id": "y_ACvKPsVUXC"
}, },
"outputs": [], "outputs": [],
...@@ -1137,7 +1023,6 @@ ...@@ -1137,7 +1023,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "eQceYqRFT_Eg" "id": "eQceYqRFT_Eg"
}, },
"source": [ "source": [
...@@ -1147,7 +1032,6 @@ ...@@ -1147,7 +1032,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "SaC1RlFawUpc" "id": "SaC1RlFawUpc"
}, },
"source": [ "source": [
...@@ -1158,7 +1042,6 @@ ...@@ -1158,7 +1042,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "CwUdjFBkzUgh" "id": "CwUdjFBkzUgh"
}, },
"source": [ "source": [
...@@ -1170,7 +1053,6 @@ ...@@ -1170,7 +1053,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "2UTQrkyOT5wD" "id": "2UTQrkyOT5wD"
}, },
"source": [ "source": [
...@@ -1181,8 +1063,6 @@ ...@@ -1181,8 +1063,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "XQeDFOzYR9Z9" "id": "XQeDFOzYR9Z9"
}, },
"outputs": [], "outputs": [],
...@@ -1195,7 +1075,6 @@ ...@@ -1195,7 +1075,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "XrFQbfErUWxa" "id": "XrFQbfErUWxa"
}, },
"source": [ "source": [
...@@ -1206,8 +1085,6 @@ ...@@ -1206,8 +1085,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "ymw7GOHpSHKU" "id": "ymw7GOHpSHKU"
}, },
"outputs": [], "outputs": [],
...@@ -1234,7 +1111,6 @@ ...@@ -1234,7 +1111,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "uX_Sp-wTUoRm" "id": "uX_Sp-wTUoRm"
}, },
"source": [ "source": [
...@@ -1245,8 +1121,6 @@ ...@@ -1245,8 +1121,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "rkHxIK57SQ_r" "id": "rkHxIK57SQ_r"
}, },
"outputs": [], "outputs": [],
...@@ -1267,7 +1141,6 @@ ...@@ -1267,7 +1141,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "stbaVouogvzS" "id": "stbaVouogvzS"
}, },
"source": [ "source": [
...@@ -1278,8 +1151,6 @@ ...@@ -1278,8 +1151,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "gwhrlQl4gxVF" "id": "gwhrlQl4gxVF"
}, },
"outputs": [], "outputs": [],
...@@ -1290,7 +1161,6 @@ ...@@ -1290,7 +1161,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "dbJ76vSJj77j" "id": "dbJ76vSJj77j"
}, },
"source": [ "source": [
...@@ -1300,7 +1170,6 @@ ...@@ -1300,7 +1170,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "9J95LFRohiYw" "id": "9J95LFRohiYw"
}, },
"source": [ "source": [
...@@ -1311,8 +1180,6 @@ ...@@ -1311,8 +1180,6 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "gCvaLLAxPuMc" "id": "gCvaLLAxPuMc"
}, },
"outputs": [], "outputs": [],
@@ -1356,8 +1223,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "rutkBadrhzdR"
 },
 "outputs": [],
@@ -1384,8 +1249,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "59TVgt4Z7fuU"
 },
 "outputs": [],
@@ -1396,7 +1259,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "QbklKt-w_CiI"
 },
 "source": [
@@ -1411,8 +1273,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "GDWrHm0BGpbX"
 },
 "outputs": [],
@@ -1426,8 +1286,6 @@
 "execution_count": null,
 "metadata": {
 "cellView": "form",
-"colab": {},
-"colab_type": "code",
 "id": "Y29meH0qGq_5"
 },
 "outputs": [],
@@ -1439,13 +1297,11 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "lo6479At4sP1"
 },
 "outputs": [],
 "source": [
-"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}/2\",\n",
+"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}/3\",\n",
 " trainable=True)\n",
 "\n",
 "print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
@@ -1454,7 +1310,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "iTzF574wivQv"
 },
 "source": [
@@ -1465,27 +1320,25 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "XEcYrCR45Uwo"
 },
 "outputs": [],
 "source": [
 "result = hub_encoder(\n",
-" inputs=[glue_train['input_word_ids'][:10],\n",
-" glue_train['input_mask'][:10],\n",
-" glue_train['input_type_ids'][:10],],\n",
+" inputs=dict(\n",
+" input_word_ids=glue_train['input_word_ids'][:10],\n",
+" input_mask=glue_train['input_mask'][:10],\n",
+" input_type_ids=glue_train['input_type_ids'][:10],),\n",
 " training=False,\n",
 ")\n",
 "\n",
-"print(\"Pooled output shape:\", result[0].shape)\n",
-"print(\"Sequence output shape:\", result[1].shape)"
+"print(\"Pooled output shape:\", result['pooled_output'].shape)\n",
+"print(\"Sequence output shape:\", result['sequence_output'].shape)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "cjojn8SmLSRI"
 },
 "source": [
@@ -1498,33 +1351,31 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "9nTDaApyLR70"
 },
 "outputs": [],
 "source": [
-"hub_classifier, hub_encoder = bert.bert_models.classifier_model(\n",
-" # Caution: Most of `bert_config` is ignored if you pass a hub url.\n",
-" bert_config=bert_config, hub_module_url=hub_url_bert, num_labels=2)"
+"hub_classifier = nlp.modeling.models.BertClassifier(\n",
+" bert_encoder,\n",
+" num_classes=2,\n",
+" dropout_rate=0.1,\n",
+" initializer=tf.keras.initializers.TruncatedNormal(\n",
+" stddev=0.02))"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "xMJX3wV0_v7I"
 },
 "source": [
-"The one downside to loading this model from TFHub is that the structure of internal keras layers is not restored. So it's more difficult to inspect or modify the model. The `TransformerEncoder` model is now a single layer:"
+"The one downside to loading this model from TFHub is that the structure of internal keras layers is not restored. So it's more difficult to inspect or modify the model. The `BertEncoder` model is now a single layer:"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "pD71dnvhM2QS"
 },
 "outputs": [],
@@ -1536,8 +1387,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "nLZD-isBzNKi"
 },
 "outputs": [],
@@ -1552,7 +1401,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "ZxSqH0dNAgXV"
 },
 "source": [
@@ -1560,13 +1408,12 @@
 "\n",
 "### Low level model building\n",
 "\n",
-"If you need a more control over the construction of the model it's worth noting that the `classifier_model` function used earlier is really just a thin wrapper over the `nlp.modeling.networks.TransformerEncoder` and `nlp.modeling.models.BertClassifier` classes. Just remember that if you start modifying the architecture it may not be correct or possible to reload the pre-trained checkpoint so you'll need to retrain from scratch."
+"If you need a more control over the construction of the model it's worth noting that the `classifier_model` function used earlier is really just a thin wrapper over the `nlp.modeling.networks.BertEncoder` and `nlp.modeling.models.BertClassifier` classes. Just remember that if you start modifying the architecture it may not be correct or possible to reload the pre-trained checkpoint so you'll need to retrain from scratch."
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "0cgABEwDj06P"
 },
 "source": [
@@ -1577,43 +1424,38 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "5r_yqhBFSVEM"
 },
 "outputs": [],
 "source": [
-"transformer_config = config_dict.copy()\n",
+"bert_encoder_config = config_dict.copy()\n",
 "\n",
 "# You need to rename a few fields to make this work:\n",
-"transformer_config['attention_dropout_rate'] = transformer_config.pop('attention_probs_dropout_prob')\n",
-"transformer_config['activation'] = tf_utils.get_activation(transformer_config.pop('hidden_act'))\n",
-"transformer_config['dropout_rate'] = transformer_config.pop('hidden_dropout_prob')\n",
-"transformer_config['initializer'] = tf.keras.initializers.TruncatedNormal(\n",
-" stddev=transformer_config.pop('initializer_range'))\n",
-"transformer_config['max_sequence_length'] = transformer_config.pop('max_position_embeddings')\n",
-"transformer_config['num_layers'] = transformer_config.pop('num_hidden_layers')\n",
+"bert_encoder_config['attention_dropout_rate'] = bert_encoder_config.pop('attention_probs_dropout_prob')\n",
+"bert_encoder_config['activation'] = tf_utils.get_activation(bert_encoder_config.pop('hidden_act'))\n",
+"bert_encoder_config['dropout_rate'] = bert_encoder_config.pop('hidden_dropout_prob')\n",
+"bert_encoder_config['initializer'] = tf.keras.initializers.TruncatedNormal(\n",
+" stddev=bert_encoder_config.pop('initializer_range'))\n",
+"bert_encoder_config['max_sequence_length'] = bert_encoder_config.pop('max_position_embeddings')\n",
+"bert_encoder_config['num_layers'] = bert_encoder_config.pop('num_hidden_layers')\n",
 "\n",
-"transformer_config"
+"bert_encoder_config"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "rIO8MI7LLijh"
 },
 "outputs": [],
 "source": [
-"manual_encoder = nlp.modeling.networks.TransformerEncoder(**transformer_config)"
+"manual_encoder = nlp.modeling.networks.BertEncoder(**bert_encoder_config)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "4a4tFSg9krRi"
 },
 "source": [
@@ -1624,21 +1466,18 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "X6N9NEqfXJCx"
 },
 "outputs": [],
 "source": [
-"checkpoint = tf.train.Checkpoint(model=manual_encoder)\n",
-"checkpoint.restore(\n",
+"checkpoint = tf.train.Checkpoint(encoder=manual_encoder)\n",
+"checkpoint.read(\n",
 " os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "1BPiPO4ykuwM"
 },
 "source": [
@@ -1649,8 +1488,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "hlVdgJKmj389"
 },
 "outputs": [],
@@ -1664,7 +1501,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "nJMXvVgJkyBv"
 },
 "source": [
@@ -1675,8 +1511,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "tQX57GJ6wkAb"
 },
 "outputs": [],
@@ -1684,17 +1518,14 @@
 "manual_classifier = nlp.modeling.models.BertClassifier(\n",
 " bert_encoder,\n",
 " num_classes=2,\n",
-" dropout_rate=transformer_config['dropout_rate'],\n",
-" initializer=tf.keras.initializers.TruncatedNormal(\n",
-" stddev=bert_config.initializer_range))"
+" dropout_rate=bert_encoder_config['dropout_rate'],\n",
+" initializer=bert_encoder_config['initializer'])"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "kB-nBWhQk0dS"
 },
 "outputs": [],
@@ -1705,7 +1536,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "E6AJlOSyIO1L"
 },
 "source": [
@@ -1720,8 +1550,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "28Dv3BPRlFTD"
 },
 "outputs": [],
@@ -1733,7 +1561,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "LRjcHr0UlT8c"
 },
 "source": [
@@ -1746,8 +1573,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "MHY8K6kDngQn"
 },
 "outputs": [],
@@ -1765,8 +1590,7 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
+"collapsed": true,
 "id": "wKIcSprulu3P"
 },
 "outputs": [],
@@ -1782,7 +1606,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "IMTC_gfAl_PZ"
 },
 "source": [
@@ -1793,8 +1616,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "YRt3VTmBmCBY"
 },
 "outputs": [],
@@ -1816,7 +1637,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "l8D9Lv3Bn740"
 },
 "source": [
@@ -1827,8 +1647,6 @@
 "cell_type": "code",
 "execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "2Hf2rpRXk89N"
 },
 "outputs": [],
...
 {
+"nbformat": 4,
+"nbformat_minor": 0,
+"metadata": {
+"colab": {
+"name": "Customizing a Transformer Encoder",
+"private_outputs": true,
+"provenance": [],
+"collapsed_sections": [],
+"toc_visible": true
+},
+"kernelspec": {
+"display_name": "Python 3",
+"name": "python3"
+}
+},
 "cells": [
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Bp8t2AI8i7uP"
 },
 "source": [
@@ -12,14 +26,10 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
 "cellView": "form",
-"colab": {},
-"colab_type": "code",
 "id": "rxPj2Lsni9O4"
 },
-"outputs": [],
 "source": [
 "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
 "# you may not use this file except in compliance with the License.\n",
@@ -32,12 +42,13 @@
 "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
 "# See the License for the specific language governing permissions and\n",
 "# limitations under the License."
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "6xS-9i5DrRvO"
 },
 "source": [
@@ -47,30 +58,28 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Mwb9uw1cDXsa"
 },
 "source": [
-"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-" \u003ctd\u003e\n",
-" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
-" \u003c/td\u003e\n",
-"\u003c/table\u003e"
+"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/customize_encoder\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
+" </td>\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
+" </td>\n",
+" <td>\n",
+" <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
+" </td>\n",
+" <td>\n",
+" <a href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/customize_encoder.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
+" </td>\n",
+"</table>"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "iLrcV4IyrcGX"
 },
 "source": [
@@ -84,7 +93,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "YYxdyoWgsl8t"
 },
 "source": [
@@ -94,7 +102,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "fEJSFutUsn_h"
 },
 "source": [
@@ -107,21 +114,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "thsKZDjhswhR"
 },
-"outputs": [],
 "source": [
-"!pip install -q tf-models-official==2.3.0"
+"!pip install -q tf-models-official==2.4.0"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "hpf7JPCVsqtv"
 },
 "source": [
@@ -130,13 +134,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "my4dp-RMssQe"
 },
-"outputs": [],
 "source": [
 "import numpy as np\n",
 "import tensorflow as tf\n",
@@ -144,12 +144,13 @@
 "from official.modeling import activations\n",
 "from official.nlp import modeling\n",
 "from official.nlp.modeling import layers, losses, models, networks"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "vjDmVsFfs85n"
 },
 "source": [
@@ -160,13 +161,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "Oav8sbgstWc-"
 },
-"outputs": [],
 "source": [
 "cfg = {\n",
 " \"vocab_size\": 100,\n",
@@ -177,22 +174,23 @@
 " \"activation\": activations.gelu,\n",
 " \"dropout_rate\": 0.1,\n",
 " \"attention_dropout_rate\": 0.1,\n",
-" \"sequence_length\": 16,\n",
+" \"max_sequence_length\": 16,\n",
 " \"type_vocab_size\": 2,\n",
 " \"initializer\": tf.keras.initializers.TruncatedNormal(stddev=0.02),\n",
 "}\n",
-"bert_encoder = modeling.networks.TransformerEncoder(**cfg)\n",
+"bert_encoder = modeling.networks.BertEncoder(**cfg)\n",
 "\n",
 "def build_classifier(bert_encoder):\n",
 " return modeling.models.BertClassifier(bert_encoder, num_classes=2)\n",
 "\n",
 "canonical_classifier_model = build_classifier(bert_encoder)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Qe2UWI6_tsHo"
 },
 "source": [
@@ -203,31 +201,28 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "csED2d-Yt5h6"
 },
-"outputs": [],
 "source": [
 "def predict(model):\n",
 " batch_size = 3\n",
 " np.random.seed(0)\n",
 " word_ids = np.random.randint(\n",
-" cfg[\"vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
-" mask = np.random.randint(2, size=(batch_size, cfg[\"sequence_length\"]))\n",
+" cfg[\"vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
+" mask = np.random.randint(2, size=(batch_size, cfg[\"max_sequence_length\"]))\n",
 " type_ids = np.random.randint(\n",
-" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"sequence_length\"]))\n",
+" cfg[\"type_vocab_size\"], size=(batch_size, cfg[\"max_sequence_length\"]))\n",
 " print(model([word_ids, mask, type_ids], training=False))\n",
 "\n",
 "predict(canonical_classifier_model)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "PzKStEK9t_Pb"
 },
 "source": [
@@ -239,7 +234,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "rmwQfhj6fmKz"
 },
 "source": [
@@ -250,7 +244,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "xsMgEVHAui11"
 },
 "source": [
@@ -263,26 +256,21 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "-JBabpa2AOz8"
 },
 "source": [
 "#### Without Customization\n",
 "\n",
-"Without any customization, `EncoderScaffold` behaves the same the canonical `TransformerEncoder`.\n",
+"Without any customization, `EncoderScaffold` behaves the same the canonical `BertEncoder`.\n",
 "\n",
-"As shown in the following example, `EncoderScaffold` can load `TransformerEncoder`'s weights and output the same values:"
+"As shown in the following example, `EncoderScaffold` can load `BertEncoder`'s weights and output the same values:"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "ktNzKuVByZQf"
 },
-"outputs": [],
 "source": [
 "default_hidden_cfg = dict(\n",
 " num_attention_heads=cfg[\"num_attention_heads\"],\n",
@@ -296,10 +284,9 @@
 " vocab_size=cfg[\"vocab_size\"],\n",
 " type_vocab_size=cfg[\"type_vocab_size\"],\n",
 " hidden_size=cfg[\"hidden_size\"],\n",
-" seq_length=cfg[\"sequence_length\"],\n",
 " initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
 " dropout_rate=cfg[\"dropout_rate\"],\n",
-" max_seq_length=cfg[\"sequence_length\"],\n",
+" max_seq_length=cfg[\"max_sequence_length\"]\n",
 ")\n",
 "default_kwargs = dict(\n",
 " hidden_cfg=default_hidden_cfg,\n",
@@ -309,17 +296,19 @@
 " return_all_layer_outputs=True,\n",
 " pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(0.02),\n",
 ")\n",
+"\n",
 "encoder_scaffold = modeling.networks.EncoderScaffold(**default_kwargs)\n",
 "classifier_model_from_encoder_scaffold = build_classifier(encoder_scaffold)\n",
 "classifier_model_from_encoder_scaffold.set_weights(\n",
 " canonical_classifier_model.get_weights())\n",
 "predict(classifier_model_from_encoder_scaffold)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "sMaUmLyIuwcs"
 },
 "source": [
@@ -332,18 +321,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "LTinnaG6vcsw"
 },
-"outputs": [],
 "source": [
 "word_ids = tf.keras.layers.Input(\n",
-" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
+" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_word_ids\")\n",
 "mask = tf.keras.layers.Input(\n",
-" shape=(cfg['sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
+" shape=(cfg['max_sequence_length'],), dtype=tf.int32, name=\"input_mask\")\n",
 "embedding_layer = modeling.layers.OnDeviceEmbedding(\n",
 " vocab_size=cfg['vocab_size'],\n",
 " embedding_width=cfg['hidden_size'],\n",
@@ -353,12 +338,13 @@
 "attention_mask = layers.SelfAttentionMask()([word_embeddings, mask])\n",
 "new_embedding_network = tf.keras.Model([word_ids, mask],\n",
 " [word_embeddings, attention_mask])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "HN7_yu-6O3qI"
 },
 "source": [
@@ -368,21 +354,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "fO9zKFE4OpHp"
 },
-"outputs": [],
 "source": [
 "tf.keras.utils.plot_model(new_embedding_network, show_shapes=True, dpi=48)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "9cOaGQHLv12W"
 },
 "source": [
@@ -391,13 +374,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "mtFDMNf2vIl9"
 },
-"outputs": [],
 "source": [
 "kwargs = dict(default_kwargs)\n",
 "\n",
@@ -412,12 +391,13 @@
 "\n",
 "# Assert that there are only two inputs.\n",
 "assert len(classifier_model.inputs) == 2"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "Z73ZQDtmwg9K"
 },
 "source": [
@@ -432,13 +412,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "uAIarLZgw6pA"
 },
-"outputs": [],
 "source": [
 "kwargs = dict(default_kwargs)\n",
 "\n",
@@ -452,12 +428,13 @@
 "\n",
 "# Assert that the variable `rezero_alpha` from ReZeroTransformer exists.\n",
 "assert 'rezero_alpha' in ''.join([x.name for x in classifier_model.trainable_weights])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "6PMHFdvnxvR0"
 },
 "source": [
@@ -470,7 +447,6 @@
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "D6FejlgwyAy_"
 },
 "source": [
@@ -485,13 +461,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "nFrSMrZuyNeQ"
 },
-"outputs": [],
 "source": [
 "# Use TalkingHeadsAttention\n",
 "hidden_cfg = dict(default_hidden_cfg)\n",
@@ -508,12 +480,13 @@
 "\n",
 "# Assert that the variable `pre_softmax_weight` from TalkingHeadsAttention exists.\n",
 "assert 'pre_softmax_weight' in ''.join([x.name for x in classifier_model.trainable_weights])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "kuEJcTyByVvI"
 },
 "source": [
@@ -528,13 +501,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "XAbKy_l4y_-i"
 },
-"outputs": [],
 "source": [
 "# Use TalkingHeadsAttention\n",
 "hidden_cfg = dict(default_hidden_cfg)\n",
@@ -551,12 +520,13 @@
 "\n",
 "# Assert that the variable `gate` from GatedFeedforward exists.\n",
 "assert 'gate' in ''.join([x.name for x in classifier_model.trainable_weights])"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "a_8NWUhkzeAq"
 },
 "source": [
@@ -564,29 +534,26 @@
 "\n",
 "Finally, you could also build a new encoder using building blocks in the modeling library.\n",
 "\n",
-"See [AlbertTransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_transformer_encoder.py) as an example:\n"
+"See [AlbertEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/albert_encoder.py) as an example:\n"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": null,
 "metadata": {
-"colab": {},
-"colab_type": "code",
 "id": "xsiA3RzUzmUM"
 },
-"outputs": [],
 "source": [
-"albert_encoder = modeling.networks.AlbertTransformerEncoder(**cfg)\n",
+"albert_encoder = modeling.networks.AlbertEncoder(**cfg)\n",
 "classifier_model = build_classifier(albert_encoder)\n",
 "# ... Train the model ...\n",
 "predict(classifier_model)"
-]
+],
+"execution_count": null,
+"outputs": []
 },
 {
 "cell_type": "markdown",
 "metadata": {
-"colab_type": "text",
 "id": "MeidDfhlHKSO"
 },
 "source": [
...@@ -595,31 +562,14 @@ ...@@ -595,31 +562,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "Uv_juT22HERW" "id": "Uv_juT22HERW"
}, },
"outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(albert_encoder, show_shapes=True, dpi=48)"
] ],
} "execution_count": null,
], "outputs": []
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Customizing a Transformer Encoder",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
} }
}, ]
"nbformat": 4, }
"nbformat_minor": 0 \ No newline at end of file
}
{ {
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Introduction to the TensorFlow Models NLP library",
"private_outputs": true,
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "80xnUmoI7fBX" "id": "80xnUmoI7fBX"
}, },
"source": [ "source": [
...@@ -12,14 +26,10 @@ ...@@ -12,14 +26,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"cellView": "form", "cellView": "form",
"colab": {},
"colab_type": "code",
"id": "8nvTnfs6Q692" "id": "8nvTnfs6Q692"
}, },
"outputs": [],
"source": [ "source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n", "# you may not use this file except in compliance with the License.\n",
...@@ -32,12 +42,13 @@ ...@@ -32,12 +42,13 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n", "# See the License for the specific language governing permissions and\n",
"# limitations under the License." "# limitations under the License."
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "WmfcMK5P5C1G" "id": "WmfcMK5P5C1G"
}, },
"source": [ "source": [
...@@ -47,30 +58,28 @@ ...@@ -47,30 +58,28 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "cH-oJ8R6AHMK" "id": "cH-oJ8R6AHMK"
}, },
"source": [ "source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" \u003ctd\u003e\n", " <td>\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", " <a target=\"_blank\" href=\"https://www.tensorflow.org/official_models/nlp/nlp_modeling_library_intro\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" \u003c/td\u003e\n", " </td>\n",
" \u003ctd\u003e\n", " <td>\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", " <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" \u003c/td\u003e\n", " </td>\n",
" \u003ctd\u003e\n", " <td>\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", " <a target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" \u003c/td\u003e\n", " </td>\n",
" \u003ctd\u003e\n", " <td>\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", " <a href=\"https://storage.googleapis.com/tensorflow_docs/models/official/colab/nlp/nlp_modeling_library_intro.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" \u003c/td\u003e\n", " </td>\n",
"\u003c/table\u003e" "</table>"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "0H_EFIhq4-MJ" "id": "0H_EFIhq4-MJ"
}, },
"source": [ "source": [
...@@ -82,7 +91,6 @@ ...@@ -82,7 +91,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "2N97-dps_nUk" "id": "2N97-dps_nUk"
}, },
"source": [ "source": [
...@@ -92,7 +100,6 @@ ...@@ -92,7 +100,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "459ygAVl_rg0" "id": "459ygAVl_rg0"
}, },
"source": [ "source": [
...@@ -105,21 +112,18 @@ ...@@ -105,21 +112,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "Y-qGkdh6_sZc" "id": "Y-qGkdh6_sZc"
}, },
"outputs": [],
"source": [ "source": [
"!pip install -q tf-models-official==2.3.0" "!pip install -q tf-models-official==2.4.0"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "e4huSSwyAG_5" "id": "e4huSSwyAG_5"
}, },
"source": [ "source": [
...@@ -128,25 +132,22 @@ ...@@ -128,25 +132,22 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "jqYXqtjBAJd9" "id": "jqYXqtjBAJd9"
}, },
"outputs": [],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"\n", "\n",
"from official.nlp import modeling\n", "from official.nlp import modeling\n",
"from official.nlp.modeling import layers, losses, models, networks" "from official.nlp.modeling import layers, losses, models, networks"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "djBQWjvy-60Y" "id": "djBQWjvy-60Y"
}, },
"source": [ "source": [
...@@ -160,38 +161,34 @@ ...@@ -160,38 +161,34 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "MKuHVlsCHmiq" "id": "MKuHVlsCHmiq"
}, },
"source": [ "source": [
"### Build a `BertPretrainer` model wrapping `TransformerEncoder`\n", "### Build a `BertPretrainer` model wrapping `BertEncoder`\n",
"\n", "\n",
"The [TransformerEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/transformer_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n", "The [BertEncoder](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/networks/bert_encoder.py) implements the Transformer-based encoder as described in [BERT paper](https://arxiv.org/abs/1810.04805). It includes the embedding lookups and transformer layers, but not the masked language model or classification task networks.\n",
"\n", "\n",
"The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives." "The [BertPretrainer](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_pretrainer.py) allows a user to pass in a transformer stack, and instantiates the masked language model and classification networks that are used to create the training objectives."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "EXkcXz-9BwB3" "id": "EXkcXz-9BwB3"
}, },
"outputs": [],
"source": [ "source": [
"# Build a small transformer network.\n", "# Build a small transformer network.\n",
"vocab_size = 100\n", "vocab_size = 100\n",
"sequence_length = 16\n", "sequence_length = 16\n",
"network = modeling.networks.TransformerEncoder(\n", "network = modeling.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=16)" " vocab_size=vocab_size, num_layers=2, sequence_length=16)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "0NH5irV5KTMS" "id": "0NH5irV5KTMS"
}, },
"source": [ "source": [
...@@ -202,37 +199,32 @@ ...@@ -202,37 +199,32 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "lZNoZkBrIoff" "id": "lZNoZkBrIoff"
}, },
"outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(network, show_shapes=True, dpi=48)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "o7eFOZXiIl-b" "id": "o7eFOZXiIl-b"
}, },
"outputs": [],
"source": [ "source": [
"# Create a BERT pretrainer with the created network.\n", "# Create a BERT pretrainer with the created network.\n",
"num_token_predictions = 8\n", "num_token_predictions = 8\n",
"bert_pretrainer = modeling.models.BertPretrainer(\n", "bert_pretrainer = modeling.models.BertPretrainer(\n",
" network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')" " network, num_classes=2, num_token_predictions=num_token_predictions, output='predictions')"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "d5h5HT7gNHx_" "id": "d5h5HT7gNHx_"
}, },
"source": [ "source": [
...@@ -241,26 +233,20 @@ ...@@ -241,26 +233,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "2tcNfm03IBF7" "id": "2tcNfm03IBF7"
}, },
"outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(bert_pretrainer, show_shapes=True, dpi=48)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "F2oHrXGUIS0M" "id": "F2oHrXGUIS0M"
}, },
"outputs": [],
"source": [ "source": [
"# We can feed some dummy data to get masked language model and sentence output.\n", "# We can feed some dummy data to get masked language model and sentence output.\n",
"batch_size = 2\n", "batch_size = 2\n",
...@@ -275,12 +261,13 @@ ...@@ -275,12 +261,13 @@
"sentence_output = outputs[\"classification\"]\n", "sentence_output = outputs[\"classification\"]\n",
"print(lm_output)\n", "print(lm_output)\n",
"print(sentence_output)" "print(sentence_output)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "bnx3UCHniCS5" "id": "bnx3UCHniCS5"
}, },
"source": [ "source": [
...@@ -290,13 +277,9 @@ ...@@ -290,13 +277,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "k30H4Q86f52x" "id": "k30H4Q86f52x"
}, },
"outputs": [],
"source": [ "source": [
"masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n", "masked_lm_ids_data = np.random.randint(vocab_size, size=(batch_size, num_token_predictions))\n",
"masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n", "masked_lm_weights_data = np.random.randint(2, size=(batch_size, num_token_predictions))\n",
...@@ -311,12 +294,13 @@ ...@@ -311,12 +294,13 @@
" predictions=sentence_output)\n", " predictions=sentence_output)\n",
"loss = mlm_loss + sentence_loss\n", "loss = mlm_loss + sentence_loss\n",
"print(loss)" "print(loss)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "wrmSs8GjHxVw" "id": "wrmSs8GjHxVw"
}, },
"source": [ "source": [
...@@ -328,7 +312,6 @@ ...@@ -328,7 +312,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "k8cQVFvBCV4s" "id": "k8cQVFvBCV4s"
}, },
"source": [ "source": [
...@@ -342,38 +325,34 @@ ...@@ -342,38 +325,34 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "xrLLEWpfknUW" "id": "xrLLEWpfknUW"
}, },
"source": [ "source": [
"### Build a BertSpanLabeler wrapping TransformerEncoder\n", "### Build a BertSpanLabeler wrapping BertEncoder\n",
"\n", "\n",
"[BertSpanLabeler](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_span_labeler.py) implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n", "[BertSpanLabeler](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_span_labeler.py) implements a simple single-span start-end predictor (that is, a model that predicts two values: a start token index and an end token index), suitable for SQuAD-style tasks.\n",
"\n", "\n",
"Note that `BertSpanLabeler` wraps a `TransformerEncoder`, the weights of which can be restored from the above pretraining model.\n" "Note that `BertSpanLabeler` wraps a `BertEncoder`, the weights of which can be restored from the above pretraining model.\n"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "B941M4iUCejO" "id": "B941M4iUCejO"
}, },
"outputs": [],
"source": [ "source": [
"network = modeling.networks.TransformerEncoder(\n", "network = modeling.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n", " vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n", "\n",
"# Create a BERT trainer with the created network.\n", "# Create a BERT trainer with the created network.\n",
"bert_span_labeler = modeling.models.BertSpanLabeler(network)" "bert_span_labeler = modeling.models.BertSpanLabeler(network)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "QpB9pgj4PpMg" "id": "QpB9pgj4PpMg"
}, },
"source": [ "source": [
...@@ -382,26 +361,20 @@ ...@@ -382,26 +361,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "RbqRNJCLJu4H" "id": "RbqRNJCLJu4H"
}, },
"outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(bert_span_labeler, show_shapes=True, dpi=48)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "fUf1vRxZJwio" "id": "fUf1vRxZJwio"
}, },
"outputs": [],
"source": [ "source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n", "# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n", "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
...@@ -412,12 +385,13 @@ ...@@ -412,12 +385,13 @@
"start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n", "start_logits, end_logits = bert_span_labeler([word_id_data, mask_data, type_id_data])\n",
"print(start_logits)\n", "print(start_logits)\n",
"print(end_logits)" "print(end_logits)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "WqhgQaN1lt-G" "id": "WqhgQaN1lt-G"
}, },
"source": [ "source": [
...@@ -427,13 +401,9 @@ ...@@ -427,13 +401,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "waqs6azNl3Nn" "id": "waqs6azNl3Nn"
}, },
"outputs": [],
"source": [ "source": [
"start_positions = np.random.randint(sequence_length, size=(batch_size))\n", "start_positions = np.random.randint(sequence_length, size=(batch_size))\n",
"end_positions = np.random.randint(sequence_length, size=(batch_size))\n", "end_positions = np.random.randint(sequence_length, size=(batch_size))\n",
...@@ -445,12 +415,13 @@ ...@@ -445,12 +415,13 @@
"\n", "\n",
"total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n", "total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2\n",
"print(total_loss)" "print(total_loss)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "Zdf03YtZmd_d" "id": "Zdf03YtZmd_d"
}, },
"source": [ "source": [
...@@ -460,7 +431,6 @@ ...@@ -460,7 +431,6 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "0A1XnGSTChg9" "id": "0A1XnGSTChg9"
}, },
"source": [ "source": [
...@@ -472,38 +442,34 @@ ...@@ -472,38 +442,34 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "MSK8OpZgnQa9" "id": "MSK8OpZgnQa9"
}, },
"source": [ "source": [
"### Build a BertClassifier model wrapping TransformerEncoder\n", "### Build a BertClassifier model wrapping BertEncoder\n",
"\n", "\n",
"[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a [CLS] token classification model containing a single classification head." "[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a [CLS] token classification model containing a single classification head."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "cXXCsffkCphk" "id": "cXXCsffkCphk"
}, },
"outputs": [],
"source": [ "source": [
"network = modeling.networks.TransformerEncoder(\n", "network = modeling.networks.BertEncoder(\n",
" vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n", " vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)\n",
"\n", "\n",
"# Create a BERT trainer with the created network.\n", "# Create a BERT trainer with the created network.\n",
"num_classes = 2\n", "num_classes = 2\n",
"bert_classifier = modeling.models.BertClassifier(\n", "bert_classifier = modeling.models.BertClassifier(\n",
" network, num_classes=num_classes)" " network, num_classes=num_classes)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "8tZKueKYP4bB" "id": "8tZKueKYP4bB"
}, },
"source": [ "source": [
...@@ -512,26 +478,20 @@ ...@@ -512,26 +478,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "snlutm9ZJgEZ" "id": "snlutm9ZJgEZ"
}, },
"outputs": [],
"source": [ "source": [
"tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)" "tf.keras.utils.plot_model(bert_classifier, show_shapes=True, dpi=48)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "yyHPHsqBJkCz" "id": "yyHPHsqBJkCz"
}, },
"outputs": [],
"source": [ "source": [
"# Create a set of 2-dimensional data tensors to feed into the model.\n", "# Create a set of 2-dimensional data tensors to feed into the model.\n",
"word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n", "word_id_data = np.random.randint(vocab_size, size=(batch_size, sequence_length))\n",
...@@ -541,12 +501,13 @@ ...@@ -541,12 +501,13 @@
"# Feed the data to the model.\n", "# Feed the data to the model.\n",
"logits = bert_classifier([word_id_data, mask_data, type_id_data])\n", "logits = bert_classifier([word_id_data, mask_data, type_id_data])\n",
"print(logits)" "print(logits)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "w--a2mg4nzKm" "id": "w--a2mg4nzKm"
}, },
"source": [ "source": [
...@@ -557,45 +518,27 @@ ...@@ -557,45 +518,27 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null,
"metadata": { "metadata": {
"colab": {},
"colab_type": "code",
"id": "9X0S1DoFn_5Q" "id": "9X0S1DoFn_5Q"
}, },
"outputs": [],
"source": [ "source": [
"labels = np.random.randint(num_classes, size=(batch_size))\n", "labels = np.random.randint(num_classes, size=(batch_size))\n",
"\n", "\n",
"loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n", "loss = modeling.losses.weighted_sparse_categorical_crossentropy_loss(\n",
" labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n", " labels=labels, predictions=tf.nn.log_softmax(logits, axis=-1))\n",
"print(loss)" "print(loss)"
] ],
"execution_count": null,
"outputs": []
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"colab_type": "text",
"id": "mzBqOylZo3og" "id": "mzBqOylZo3og"
}, },
"source": [ "source": [
"With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example." "With the `loss`, you can optimize the model. Please see [run_classifier.py](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) or the colab [fine_tuning_bert.ipynb](https://github.com/tensorflow/models/blob/master/official/colab/fine_tuning_bert.ipynb) for the full example."
] ]
} }
], ]
"metadata": { }
"colab": { \ No newline at end of file
"collapsed_sections": [],
"name": "Introduction to the TensorFlow Models NLP library",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
...@@ -127,6 +127,15 @@ def get_distribution_strategy(distribution_strategy="mirrored", ...@@ -127,6 +127,15 @@ def get_distribution_strategy(distribution_strategy="mirrored",
if num_gpus < 0: if num_gpus < 0:
raise ValueError("`num_gpus` can not be negative.") raise ValueError("`num_gpus` can not be negative.")
if not isinstance(distribution_strategy, str):
msg = ("distribution_strategy must be a string but got: %s." %
(distribution_strategy,))
if distribution_strategy == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
msg += (" If you meant to pass the string 'off', make sure you add "
"quotes around 'off' so that yaml interprets it as a string "
"instead of a bool.")
raise ValueError(msg)
distribution_strategy = distribution_strategy.lower() distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off": if distribution_strategy == "off":
if num_gpus > 1: if num_gpus > 1:
......
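# A hypothetical illustration of the gotcha the new error message guards
# against: YAML 1.1 (as implemented by PyYAML) parses an unquoted `off` as
# the boolean False, so a config line `distribution_strategy: off` reaches
# get_distribution_strategy() as a bool rather than the string 'off'.
import yaml

parsed = yaml.safe_load('distribution_strategy: off')
assert parsed['distribution_strategy'] is False   # not the string 'off'
parsed = yaml.safe_load("distribution_strategy: 'off'")
assert parsed['distribution_strategy'] == 'off'   # quoting keeps it a string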
...@@ -41,6 +41,19 @@ class GetDistributionStrategyTest(tf.test.TestCase): ...@@ -41,6 +41,19 @@ class GetDistributionStrategyTest(tf.test.TestCase):
for device in ds.extended.worker_devices: for device in ds.extended.worker_devices:
self.assertIn('GPU', device) self.assertIn('GPU', device)
def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off')
self.assertIsNone(ds)
def test_invalid_strategy(self):
with self.assertRaisesRegexp(
ValueError,
'distribution_strategy must be a string but got: False. If'):
distribute_utils.get_distribution_strategy(False)
with self.assertRaisesRegexp(
ValueError, 'distribution_strategy must be a string but got: 1'):
distribute_utils.get_distribution_strategy(1)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""A common dataset reader.""" """A common dataset reader."""
import random import random
from typing import Any, Callable, Optional from typing import Any, Callable, List, Optional
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -27,6 +27,13 @@ def _get_random_integer(): ...@@ -27,6 +27,13 @@ def _get_random_integer():
return random.randint(0, (1 << 31) - 1) return random.randint(0, (1 << 31) - 1)
def _maybe_map_fn(dataset: tf.data.Dataset,
fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
"""Calls dataset.map if a valid function is passed in."""
return dataset if fn is None else dataset.map(
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class InputReader: class InputReader:
"""Input reader that returns a tf.data.Dataset instance.""" """Input reader that returns a tf.data.Dataset instance."""
...@@ -74,38 +81,7 @@ class InputReader: ...@@ -74,38 +81,7 @@ class InputReader:
self._tfds_builder = None self._tfds_builder = None
self._matched_files = [] self._matched_files = []
if params.input_path: if params.input_path:
# Read dataset from files. self._matched_files = self._match_files(params.input_path)
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
'"a,b,c", or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(params.input_path, str):
input_path_list = [params.input_path]
elif isinstance(params.input_path, (list, tuple)):
if any(not isinstance(x, str) for x in params.input_path):
raise ValueError(usage % params.input_path)
input_path_list = params.input_path
else:
raise ValueError(usage % params.input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
self._matched_files.extend(tmp_matched_files)
else:
self._matched_files.append(input_pattern)
if not self._matched_files:
raise ValueError('%s does not match any files.' % params.input_path)
else: else:
# Read dataset from TFDS. # Read dataset from TFDS.
if not params.tfds_split: if not params.tfds_split:
...@@ -135,7 +111,10 @@ class InputReader: ...@@ -135,7 +111,10 @@ class InputReader:
self._parser_fn = parser_fn self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn self._postprocess_fn = postprocess_fn
self._seed = _get_random_integer() # When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None.
self._seed = (None
if params.enable_tf_data_service else _get_random_integer())
self._enable_tf_data_service = ( self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address) params.enable_tf_data_service and params.tf_data_service_address)
...@@ -148,15 +127,57 @@ class InputReader: ...@@ -148,15 +127,57 @@ class InputReader:
self._enable_round_robin_tf_data_service = params.get( self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False) 'enable_round_robin_tf_data_service', False)
def _match_files(self, input_path: str) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
             'paths/patterns separated by comma (e.g. "a, b, c" or no spaces '
             '"a,b,c"), or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(input_path, str):
input_path_list = [input_path]
elif isinstance(input_path, (list, tuple)):
if any(not isinstance(x, str) for x in input_path):
raise ValueError(usage % input_path)
input_path_list = input_path
else:
raise ValueError(usage % input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
matched_files.extend(tmp_matched_files)
else:
matched_files.append(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_path)
return matched_files
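  # A sketch of the three accepted `input_path` forms documented above
  # (placeholder paths; the glob patterns would need to match real files):
  #   'gs://bucket/data/train-*'               # (1) a single path/pattern
  #   'train-00000,train-00001, train-00002'   # (2) comma-separated patterns
  #   ['train-*', 'extra-*,eval-*']            # (3) a list of str patterns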
def _shard_files_then_read( def _shard_files_then_read(
self, input_context: Optional[tf.distribute.InputContext] = None): self,
matched_files: List[str],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Shards the data files and then sent a split to every worker to read.""" """Shards the data files and then sent a split to every worker to read."""
dataset = tf.data.Dataset.from_tensor_slices(self._matched_files) dataset = tf.data.Dataset.from_tensor_slices(matched_files)
# Shuffle and repeat at file level. # Shuffle and repeat at file level.
if self._is_training: if self._is_training:
dataset = dataset.shuffle( dataset = dataset.shuffle(
len(self._matched_files), len(matched_files),
seed=self._seed, seed=self._seed,
reshuffle_each_iteration=True) reshuffle_each_iteration=True)
...@@ -171,7 +192,7 @@ class InputReader: ...@@ -171,7 +192,7 @@ class InputReader:
dataset = dataset.repeat() dataset = dataset.repeat()
dataset = dataset.interleave( dataset = dataset.interleave(
map_func=self._dataset_fn, map_func=dataset_fn,
cycle_length=self._cycle_length, cycle_length=self._cycle_length,
block_length=self._block_length, block_length=self._block_length,
num_parallel_calls=(self._cycle_length if self._cycle_length else num_parallel_calls=(self._cycle_length if self._cycle_length else
...@@ -180,9 +201,13 @@ class InputReader: ...@@ -180,9 +201,13 @@ class InputReader:
return dataset return dataset
def _read_files_then_shard( def _read_files_then_shard(
self, input_context: Optional[tf.distribute.InputContext] = None): self,
matched_files: List[str],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Sends all data files to every worker and then shard by data.""" """Sends all data files to every worker and then shard by data."""
dataset = self._dataset_fn(self._matched_files) dataset = dataset_fn(matched_files)
# When `input_file` is a path to a single file or the number of files is # When `input_file` is a path to a single file or the number of files is
# less than the number of input pipelines, disable auto sharding # less than the number of input pipelines, disable auto sharding
...@@ -238,26 +263,35 @@ class InputReader: ...@@ -238,26 +263,35 @@ class InputReader:
raise ValueError('tfds_info is not available, because the dataset ' raise ValueError('tfds_info is not available, because the dataset '
'is not loaded from tfds.') 'is not loaded from tfds.')
def read( def _read_decode_and_parse_dataset(
self, self,
input_context: Optional[tf.distribute.InputContext] = None matched_files: List[str],
) -> tf.data.Dataset: dataset_fn,
"""Generates a tf.data.Dataset object.""" batch_size: int,
if self._tfds_builder: input_context: Optional[tf.distribute.InputContext] = None,
tfds_builder: bool = False) -> tf.data.Dataset:
"""Returns a tf.data.Dataset object after reading, decoding, and parsing."""
if tfds_builder:
dataset = self._read_tfds(input_context) dataset = self._read_tfds(input_context)
elif len(self._matched_files) > 1: elif len(matched_files) > 1:
if input_context and (len(self._matched_files) < if input_context and (len(matched_files) <
input_context.num_input_pipelines): input_context.num_input_pipelines):
logging.warn( logging.warn(
'The number of files %d is less than the number of input pipelines ' 'The number of files %d is less than the number of input pipelines '
'%d. We will send all input files to every worker. ' '%d. We will send all input files to every worker. '
'Please consider sharding your data into more files.', 'Please consider sharding your data into more files.',
len(self._matched_files), input_context.num_input_pipelines) len(matched_files), input_context.num_input_pipelines)
dataset = self._read_files_then_shard(input_context) dataset = self._read_files_then_shard(matched_files,
dataset_fn,
input_context)
else: else:
dataset = self._shard_files_then_read(input_context) dataset = self._shard_files_then_read(matched_files,
elif len(self._matched_files) == 1: dataset_fn,
dataset = self._read_files_then_shard(input_context) input_context)
elif len(matched_files) == 1:
dataset = self._read_files_then_shard(matched_files,
dataset_fn,
input_context)
else: else:
raise ValueError('It is unexpected that `tfds_builder` is None and ' raise ValueError('It is unexpected that `tfds_builder` is False and '
'there is also no `matched_files`.') 'there is also no `matched_files`.')
...@@ -268,25 +302,28 @@ class InputReader: ...@@ -268,25 +302,28 @@ class InputReader:
if self._is_training: if self._is_training:
dataset = dataset.shuffle(self._shuffle_buffer_size) dataset = dataset.shuffle(self._shuffle_buffer_size)
def maybe_map_fn(dataset, fn): dataset = _maybe_map_fn(dataset, self._decoder_fn)
return dataset if fn is None else dataset.map(
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = maybe_map_fn(dataset, self._decoder_fn)
if self._sample_fn is not None: if self._sample_fn is not None:
dataset = dataset.apply(self._sample_fn) dataset = dataset.apply(self._sample_fn)
dataset = maybe_map_fn(dataset, self._parser_fn) dataset = _maybe_map_fn(dataset, self._parser_fn)
if self._transform_and_batch_fn is not None: if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context) dataset = self._transform_and_batch_fn(dataset, input_context)
else: else:
per_replica_batch_size = input_context.get_per_replica_batch_size( per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size batch_size) if input_context else batch_size
dataset = dataset.batch( dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder) per_replica_batch_size, drop_remainder=self._drop_remainder
)
dataset = maybe_map_fn(dataset, self._postprocess_fn) return dataset
def _maybe_apply_data_service(
self,
dataset: tf.data.Dataset,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Potentially distributes a dataset."""
if self._enable_tf_data_service and input_context: if self._enable_tf_data_service and input_context:
if self._enable_round_robin_tf_data_service: if self._enable_round_robin_tf_data_service:
replicas_per_input_pipeline = input_context.num_replicas_in_sync // ( replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
...@@ -316,6 +353,20 @@ class InputReader: ...@@ -316,6 +353,20 @@ class InputReader:
processing_mode='parallel_epochs', processing_mode='parallel_epochs',
service=self._tf_data_service_address, service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name)) job_name=self._tf_data_service_job_name))
return dataset
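  # A hypothetical params sketch for the tf.data service branch above; the
  # field names mirror the attributes read in __init__ and may not match the
  # exact DataConfig definition:
  #   enable_tf_data_service: true
  #   tf_data_service_address: 'grpc://dispatcher:5050'   # placeholder address
  #   tf_data_service_job_name: 'train_job'
  #   enable_round_robin_tf_data_service: false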
def read(
self,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
dataset = self._read_decode_and_parse_dataset(self._matched_files,
self._dataset_fn,
self._global_batch_size,
input_context,
self._tfds_builder)
dataset = _maybe_map_fn(dataset, self._postprocess_fn)
dataset = self._maybe_apply_data_service(dataset, input_context)
if self._deterministic is not None: if self._deterministic is not None:
options = tf.data.Options() options = tf.data.Options()
......
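# A minimal usage sketch for the refactored reader, assuming a DataConfig
# `data_config` with `input_path` and `global_batch_size` set; `decode_fn`
# and `parse_fn` are hypothetical user callables.
reader = InputReader(
    params=data_config,
    dataset_fn=tf.data.TFRecordDataset,
    decoder_fn=decode_fn,
    parser_fn=parse_fn)
dataset = reader.read()  # pass a tf.distribute.InputContext when sharding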
...@@ -27,26 +27,7 @@ from official.core import config_definitions ...@@ -27,26 +27,7 @@ from official.core import config_definitions
from official.core import train_utils from official.core import train_utils
BestCheckpointExporter = train_utils.BestCheckpointExporter BestCheckpointExporter = train_utils.BestCheckpointExporter
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
"""Maybe create a BestCheckpointExporter object, according to the config."""
export_subdir = params.trainer.best_checkpoint_export_subdir
metric_name = params.trainer.best_checkpoint_eval_metric
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(
best_ckpt_dir, metric_name, metric_comp)
logging.info(
'Created the best checkpoint exporter. '
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
export_subdir, metric_name)
else:
best_ckpt_exporter = None
return best_ckpt_exporter
def run_experiment(distribution_strategy: tf.distribute.Strategy, def run_experiment(distribution_strategy: tf.distribute.Strategy,
...@@ -83,7 +64,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy, ...@@ -83,7 +64,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
task, task,
train='train' in mode, train='train' in mode,
evaluate=('eval' in mode) or run_post_eval, evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir)) checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))
if trainer.checkpoint: if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager( checkpoint_manager = tf.train.CheckpointManager(
......
...@@ -17,7 +17,7 @@ import copy ...@@ -17,7 +17,7 @@ import copy
import json import json
import os import os
import pprint import pprint
from typing import List, Optional from typing import Any, Callable, Dict, List, Optional
from absl import logging from absl import logging
import dataclasses import dataclasses
...@@ -32,6 +32,75 @@ from official.core import exp_factory ...@@ -32,6 +32,75 @@ from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
def get_leaf_nested_dict(
    d: Dict[str, Any], keys: List[str]) -> Any:
  """Gets a leaf from a dictionary of arbitrary depth using a list of keys.
Args:
d: The dictionary to extract value from.
keys: The list of keys to extract values recursively.
Returns:
The value of the leaf.
Raises:
KeyError: If the value of keys extracted is a dictionary.
"""
leaf = d
for k in keys:
if not isinstance(leaf, dict) or k not in leaf:
raise KeyError(
          'Path does not exist while traversing the dictionary with keys'
          ': %s.' % keys)
leaf = leaf[k]
if isinstance(leaf, dict):
raise KeyError('The value extracted with keys: %s is not a leaf of the '
'dictionary: %s.' % (keys, d))
return leaf
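# Illustrative use: extract a nested metric value with a list of keys.
example_logs = {'task_a': {'accuracy': 0.91}, 'loss': 1.2}
assert get_leaf_nested_dict(example_logs, ['task_a', 'accuracy']) == 0.91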
def cast_leaf_nested_dict(
d: Dict[str, Any],
cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
"""Cast the leaves of a dictionary with arbitrary depth in place.
Args:
d: The dictionary to extract value from.
cast_fn: The casting function.
Returns:
    A dictionary with the same structure as d.
"""
for key, value in d.items():
if isinstance(value, dict):
d[key] = cast_leaf_nested_dict(value, cast_fn)
else:
d[key] = cast_fn(value)
return d
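# Illustrative use: cast every leaf in place, as done for JSON export below.
example_logs = {'task_a': {'accuracy': '0.91'}, 'step': '100'}
cast_leaf_nested_dict(example_logs, float)
# example_logs is now {'task_a': {'accuracy': 0.91}, 'step': 100.0}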
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
"""Maybe create a BestCheckpointExporter object, according to the config."""
export_subdir = params.trainer.best_checkpoint_export_subdir
metric_name = params.trainer.best_checkpoint_eval_metric
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(
best_ckpt_dir, metric_name, metric_comp)
logging.info(
'Created the best checkpoint exporter. '
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
export_subdir, metric_name)
else:
best_ckpt_exporter = None
return best_ckpt_exporter
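# The exporter is switched on by these trainer config fields (a hypothetical
# YAML sketch; `|` in the metric name selects a leaf of a nested eval_logs
# dict, per get_leaf_nested_dict above):
#   trainer:
#     best_checkpoint_export_subdir: 'best_ckpt'
#     best_checkpoint_eval_metric: 'task_a|accuracy'
#     best_checkpoint_metric_comp: 'higher'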
# TODO(b/180147589): Add tests for this module.
class BestCheckpointExporter: class BestCheckpointExporter:
"""Keeps track of the best result, and saves its checkpoint. """Keeps track of the best result, and saves its checkpoint.
...@@ -45,17 +114,32 @@ class BestCheckpointExporter: ...@@ -45,17 +114,32 @@ class BestCheckpointExporter:
Args: Args:
export_dir: The directory that will contain exported checkpoints. export_dir: The directory that will contain exported checkpoints.
metric_name: Indicates which metric to look at, when determining which metric_name: Indicates which metric to look at, when determining which
result is better. result is better. If the eval_logs passed to maybe_export_checkpoint
is a nested dictionary, use `|` as a separator between the nested keys.
metric_comp: Indicates how to compare results. Either `lower` or `higher`. metric_comp: Indicates how to compare results. Either `lower` or `higher`.
""" """
self._export_dir = export_dir self._export_dir = export_dir
self._metric_name = metric_name self._metric_name = metric_name.split('|')
self._metric_comp = metric_comp self._metric_comp = metric_comp
if self._metric_comp not in ('lower', 'higher'): if self._metric_comp not in ('lower', 'higher'):
raise ValueError('best checkpoint metric comp must be one of ' raise ValueError('best checkpoint metric comp must be one of '
'higher, lower. Got: {}'.format(self._metric_comp)) 'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path)) tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric() self._best_ckpt_logs = self._maybe_load_best_eval_metric()
self._checkpoint_manager = None
def _get_checkpoint_manager(self, checkpoint):
"""Gets an existing checkpoint manager or creates a new one."""
if self._checkpoint_manager is None or (
self._checkpoint_manager.checkpoint != checkpoint):
      logging.info('Creating a new checkpoint manager.')
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=self._export_dir,
max_to_keep=1,
checkpoint_name='best_ckpt')
return self._checkpoint_manager
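  # A direct-construction sketch (hypothetical values); with nested eval_logs
  # the metric name joins the keys with `|`:
  #   exporter = BestCheckpointExporter(
  #       export_dir='/tmp/best_ckpt_export',
  #       metric_name='task_a|accuracy',
  #       metric_comp='higher')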
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step): def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d', logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
...@@ -74,12 +158,10 @@ class BestCheckpointExporter: ...@@ -74,12 +158,10 @@ class BestCheckpointExporter:
def _new_metric_is_better(self, old_logs, new_logs): def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs.""" """Check if the metric in new_logs is better than the metric in old_logs."""
if self._metric_name not in old_logs or self._metric_name not in new_logs: old_value = float(orbit.utils.get_value(
raise KeyError('best checkpoint eval metric name {} is not valid. ' get_leaf_nested_dict(old_logs, self._metric_name)))
'old_logs: {}, new_logs: {}'.format( new_value = float(orbit.utils.get_value(
self._metric_name, old_logs, new_logs)) get_leaf_nested_dict(new_logs, self._metric_name)))
old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f', logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
old_value, new_value) old_value, new_value)
...@@ -99,16 +181,13 @@ class BestCheckpointExporter: ...@@ -99,16 +181,13 @@ class BestCheckpointExporter:
"""Export evaluation results of the best checkpoint into a json file.""" """Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs) eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step eval_logs_ext['best_ckpt_global_step'] = global_step
for name, value in eval_logs_ext.items(): eval_logs_ext = cast_leaf_nested_dict(
eval_logs_ext[name] = float(orbit.utils.get_value(value)) eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
# Saving json file is very fast. # Saving json file is very fast.
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer: with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n') writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
# Saving the best checkpoint might be interrupted if the job got killed. self._get_checkpoint_manager(checkpoint).save()
for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
tf.io.gfile.remove(file_to_remove)
checkpoint.write(self.best_ckpt_path)
@property @property
def best_ckpt_logs(self): def best_ckpt_logs(self):
...@@ -120,7 +199,8 @@ class BestCheckpointExporter: ...@@ -120,7 +199,8 @@ class BestCheckpointExporter:
@property @property
def best_ckpt_path(self): def best_ckpt_path(self):
return os.path.join(self._export_dir, 'best_ckpt') """Returns the best ckpt path or None if there is no ckpt yet."""
return tf.train.latest_checkpoint(self._export_dir)
@gin.configurable @gin.configurable
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.train_utils."""
import tensorflow as tf
from official.core import train_utils
class TrainUtilsTest(tf.test.TestCase):
def test_get_leaf_nested_dict(self):
d = {'a': {'i': {'x': 5}}}
self.assertEqual(train_utils.get_leaf_nested_dict(d, ['a', 'i', 'x']), 5)
def test_get_leaf_nested_dict_not_leaf(self):
with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
d = {'a': {'i': {'x': 5}}}
train_utils.get_leaf_nested_dict(d, ['a', 'i'])
def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
    with self.assertRaisesRegex(KeyError, 'Path does not exist while traversing .*'):
d = {'a': {'i': {'x': 5}}}
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'y'])
def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
    with self.assertRaisesRegex(KeyError, 'Path does not exist while traversing .*'):
d = {'a': {'i': {'x': 5}}}
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
    with self.assertRaisesRegex(KeyError, 'Path does not exist while traversing .*'):
d = {'a': {'i': 5}}
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
def test_cast_leaf_nested_dict(self):
d = {'a': {'i': {'x': '123'}}, 'b': 456.5}
d = train_utils.cast_leaf_nested_dict(d, int)
self.assertEqual(d['a']['i']['x'], 123)
self.assertEqual(d['b'], 456)
if __name__ == '__main__':
tf.test.main()
...@@ -37,16 +37,10 @@ class MultiTaskConfig(hyperparams.Config): ...@@ -37,16 +37,10 @@ class MultiTaskConfig(hyperparams.Config):
@dataclasses.dataclass @dataclasses.dataclass
class MultiEvalExperimentConfig(hyperparams.Config): class MultiEvalExperimentConfig(cfg.ExperimentConfig):
"""An experiment config for single-task training and multi-task evaluation. """An experiment config for single-task training and multi-task evaluation.
Attributes: Attributes:
task: the single-stream training task.
eval_tasks: individual evaluation tasks. eval_tasks: individual evaluation tasks.
trainer: the trainer configuration.
runtime: the runtime configuration.
""" """
task: cfg.TaskConfig = cfg.TaskConfig()
eval_tasks: MultiTaskConfig = MultiTaskConfig() eval_tasks: MultiTaskConfig = MultiTaskConfig()
trainer: cfg.TrainerConfig = cfg.TrainerConfig()
runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
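# With the change above, MultiEvalExperimentConfig inherits `task`,
# `trainer`, and `runtime` from cfg.ExperimentConfig, so only the extra
# `eval_tasks` field is declared here. A hypothetical instantiation:
experiment = MultiEvalExperimentConfig(
    task=cfg.TaskConfig(),
    eval_tasks=MultiTaskConfig())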
...@@ -21,6 +21,7 @@ import gin ...@@ -21,6 +21,7 @@ import gin
import orbit import orbit
import tensorflow as tf import tensorflow as tf
from official.core import train_utils
from official.modeling.multitask import base_model from official.modeling.multitask import base_model
from official.modeling.multitask import multitask from official.modeling.multitask import multitask
...@@ -29,16 +30,20 @@ from official.modeling.multitask import multitask ...@@ -29,16 +30,20 @@ from official.modeling.multitask import multitask
class MultiTaskEvaluator(orbit.AbstractEvaluator): class MultiTaskEvaluator(orbit.AbstractEvaluator):
"""Implements the common trainer shared for TensorFlow models.""" """Implements the common trainer shared for TensorFlow models."""
def __init__(self, def __init__(
task: multitask.MultiTask, self,
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel], task: multitask.MultiTask,
global_step: Optional[tf.Variable] = None): model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models. """Initialize common trainer for TensorFlow models.
Args: Args:
task: A multitask.MultiTask instance. task: A multitask.MultiTask instance.
model: tf.keras.Model instance. model: tf.keras.Model instance.
global_step: the global step variable. global_step: the global step variable.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
""" """
# Gets the current distribution strategy. If not inside any strategy scope, # Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy. # it gets a single-replica no-op strategy.
...@@ -46,19 +51,10 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -46,19 +51,10 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
self._task = task self._task = task
self._model = model self._model = model
self._global_step = global_step or orbit.utils.create_global_step() self._global_step = global_step or orbit.utils.create_global_step()
# TODO(hongkuny): Define a more robust way to handle the training/eval self._checkpoint_exporter = checkpoint_exporter
# checkpoint loading.
if hasattr(self.model, "checkpoint_items"):
# Each evaluation task can have different models and load a subset of
# components from the training checkpoint. This is assuming the
# checkpoint items are able to load the weights of the evaluation model.
checkpoint_items = self.model.checkpoint_items
else:
# This is assuming the evaluation model is exactly the training model.
checkpoint_items = dict(model=self.model)
self._checkpoint = tf.train.Checkpoint( self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step, global_step=self.global_step,
**checkpoint_items) model=self.model)
self._validation_losses = None self._validation_losses = None
self._validation_metrics = None self._validation_metrics = None
...@@ -168,4 +164,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -168,4 +164,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
metrics = task.reduce_aggregated_logs(outputs) metrics = task.reduce_aggregated_logs(outputs)
logs.update(metrics) logs.update(metrics)
results[name] = logs results[name] = logs
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, results, self.global_step.numpy())
return results return results
...@@ -20,6 +20,7 @@ import orbit ...@@ -20,6 +20,7 @@ import orbit
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
from official.core import base_trainer as core_lib from official.core import base_trainer as core_lib
from official.core import train_utils
from official.modeling.multitask import configs from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask from official.modeling.multitask import multitask
...@@ -73,7 +74,9 @@ def run_experiment_with_multitask_eval( ...@@ -73,7 +74,9 @@ def run_experiment_with_multitask_eval(
evaluator = evaluator_lib.MultiTaskEvaluator( evaluator = evaluator_lib.MultiTaskEvaluator(
task=eval_tasks, task=eval_tasks,
model=model, model=model,
global_step=trainer.global_step if is_training else None) global_step=trainer.global_step if is_training else None,
checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
params, model_dir))
else: else:
evaluator = None evaluator = None
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to export the ALBERT core model as a TF-Hub SavedModel."""
# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
from typing import Text, Tuple
from official.nlp.albert import configs
from official.nlp.bert import bert_models
FLAGS = flags.FLAGS
flags.DEFINE_string("albert_config_file", None,
"Albert configuration file to define core albert layers.")
flags.DEFINE_string("model_checkpoint_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
flags.DEFINE_string(
"sp_model_file", None,
"The sentence piece model file that the ALBERT model was trained on.")
def create_albert_model(
albert_config: configs.AlbertConfig) -> Tuple[tf.keras.Model, tf.keras.Model]:
"""Creates an ALBERT keras core model from ALBERT configuration.
Args:
albert_config: An `AlbertConfig` to create the core model.
Returns:
A tuple of the core keras model and its transformer encoder.
"""
# Adds input layers just as placeholders.
input_word_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_mask")
input_type_ids = tf.keras.layers.Input(
shape=(None,), dtype=tf.int32, name="input_type_ids")
transformer_encoder = bert_models.get_transformer_encoder(
albert_config, sequence_length=None)
sequence_output, pooled_output = transformer_encoder(
[input_word_ids, input_mask, input_type_ids])
# To stay consistent with legacy hub modules, the outputs are
# "pooled_output" and "sequence_output".
return tf.keras.Model(
inputs=[input_word_ids, input_mask, input_type_ids],
outputs=[pooled_output, sequence_output]), transformer_encoder
def export_albert_tfhub(albert_config: configs.AlbertConfig,
model_checkpoint_path: Text, hub_destination: Text,
sp_model_file: Text):
"""Restores a tf.keras.Model and saves for TF-Hub."""
core_model, encoder = create_albert_model(albert_config)
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.restore(model_checkpoint_path).assert_consumed()
core_model.sp_model_file = tf.saved_model.Asset(sp_model_file)
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
def main(_):
albert_config = configs.AlbertConfig.from_json_file(
FLAGS.albert_config_file)
export_albert_tfhub(albert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.sp_model_file)
if __name__ == "__main__":
app.run(main)
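A short usage sketch of the script above, calling the export function directly from Python rather than via flags; all paths are illustrative assumptions, not values shipped with the repository:

from official.nlp.albert import configs
from official.nlp.albert import export_albert_tfhub

# Assumed paths; substitute your own config, checkpoint and SentencePiece model.
albert_config = configs.AlbertConfig.from_json_file("/tmp/albert_config.json")
export_albert_tfhub.export_albert_tfhub(
    albert_config,
    model_checkpoint_path="/tmp/albert/ckpt-1",
    hub_destination="/tmp/albert_hub",
    sp_model_file="/tmp/30k-clean.model")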
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests official.nlp.albert.export_albert_tfhub."""
import os
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from official.nlp.albert import configs
from official.nlp.albert import export_albert_tfhub
class ExportAlbertTfhubTest(tf.test.TestCase):
def test_export_albert_tfhub(self):
# Exports a SavedModel for TF-Hub.
albert_config = configs.AlbertConfig(
vocab_size=100,
embedding_size=8,
hidden_size=16,
intermediate_size=32,
max_position_embeddings=128,
num_attention_heads=2,
num_hidden_layers=1)
bert_model, encoder = export_albert_tfhub.create_albert_model(albert_config)
model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
checkpoint = tf.train.Checkpoint(model=encoder)
checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
sp_model_file = os.path.join(self.get_temp_dir(), "sp_tokenizer.model")
with tf.io.gfile.GFile(sp_model_file, "w") as f:
f.write("dummy content")
hub_destination = os.path.join(self.get_temp_dir(), "hub")
export_albert_tfhub.export_albert_tfhub(
albert_config,
model_checkpoint_path,
hub_destination,
sp_model_file=sp_model_file)
# Restores a hub KerasLayer.
hub_layer = hub.KerasLayer(hub_destination, trainable=True)
if hasattr(hub_layer, "resolved_object"):
with tf.io.gfile.GFile(
hub_layer.resolved_object.sp_model_file.asset_path.numpy()) as f:
self.assertEqual("dummy content", f.read())
# Checks the hub KerasLayer.
for source_weight, hub_weight in zip(bert_model.trainable_weights,
hub_layer.trainable_weights):
self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
dummy_ids = np.zeros((2, 10), dtype=np.int32)
hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
# The outputs of the hub module are "pooled_output" and "sequence_output",
# while the encoder outputs them in reversed order, i.e.,
# "sequence_output" and "pooled_output".
encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
self.assertEqual(hub_outputs[0].shape, (2, 16))
self.assertEqual(hub_outputs[1].shape, (2, 10, 16))
for source_output, hub_output, encoder_output in zip(
source_outputs, hub_outputs, encoder_outputs):
self.assertAllClose(source_output.numpy(), hub_output.numpy())
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
if __name__ == "__main__":
tf.test.main()
...@@ -65,6 +65,7 @@ ALBERT_NAME_REPLACEMENTS = ( ...@@ -65,6 +65,7 @@ ALBERT_NAME_REPLACEMENTS = (
("ffn_1/intermediate/output/dense", "output"), ("ffn_1/intermediate/output/dense", "output"),
("transformer/LayerNorm_1/", "transformer/output_layer_norm/"), ("transformer/LayerNorm_1/", "transformer/output_layer_norm/"),
("pooler/dense", "pooler_transform"), ("pooler/dense", "pooler_transform"),
("cls/predictions", "bert/cls/predictions"),
("cls/predictions/output_bias", "cls/predictions/output_bias/bias"), ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"), ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
("cls/seq_relationship/output_weights", ("cls/seq_relationship/output_weights",
...@@ -113,6 +114,8 @@ def _create_pretrainer_model(cfg): ...@@ -113,6 +114,8 @@ def _create_pretrainer_model(cfg):
mlm_activation=tf_utils.get_activation(cfg.hidden_act), mlm_activation=tf_utils.get_activation(cfg.hidden_act),
mlm_initializer=tf.keras.initializers.TruncatedNormal( mlm_initializer=tf.keras.initializers.TruncatedNormal(
stddev=cfg.initializer_range)) stddev=cfg.initializer_range))
# Makes sure the masked_lm layer's variables in the pretrainer are created.
_ = pretrainer(pretrainer.inputs)
return pretrainer return pretrainer
......
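The `("cls/predictions", "bert/cls/predictions")` entry added above extends a table of (old, new) substring rewrites applied to TF1 variable names during conversion. A minimal sketch of how such a table is typically applied; the helper below is illustrative, not the repository's converter utility:

def convert_name(tf1_name, name_replacements):
  # Apply each (old, new) substring rewrite in table order; order matters
  # when one prefix is contained in another.
  for old, new in name_replacements:
    tf1_name = tf1_name.replace(old, new)
  return tf1_name

replacements = (("cls/predictions", "bert/cls/predictions"),)
print(convert_name("cls/predictions/output_weights", replacements))
# -> "bert/cls/predictions/output_weights"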
...@@ -12,14 +12,19 @@ ...@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A script to export the BERT core model as a TF-Hub SavedModel.""" """A script to export BERT as a TF-Hub SavedModel.
This script is **DEPRECATED** for exporting BERT encoder models;
see the error message in main() for details.
"""
from typing import Text
# Import libraries # Import libraries
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from typing import Text
from official.nlp.bert import bert_models from official.nlp.bert import bert_models
from official.nlp.bert import configs from official.nlp.bert import configs
...@@ -112,6 +117,14 @@ def export_bert_squad_tfhub(bert_config: configs.BertConfig, ...@@ -112,6 +117,14 @@ def export_bert_squad_tfhub(bert_config: configs.BertConfig,
def main(_): def main(_):
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file) bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
if FLAGS.model_type == "encoder": if FLAGS.model_type == "encoder":
deprecation_note = (
"nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
"models. Please switch to nlp/tools/export_tfhub for exporting BERT "
"(and other) encoders with dict inputs/outputs conforming to "
"https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
)
logging.error(deprecation_note)
print("\n\nNOTICE:", deprecation_note, "\n")
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path, export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case) FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
elif FLAGS.model_type == "squad": elif FLAGS.model_type == "squad":
......
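The deprecation note points to the dict-based encoder interface defined by the linked TF-Hub common SavedModel API for transformer encoders. A small sketch of consuming an encoder exported in that style; the export path is an assumption:

import tensorflow as tf
import tensorflow_hub as hub

# Assumed path to an encoder exported with nlp/tools/export_tfhub.
encoder = hub.KerasLayer("/tmp/exported_bert_encoder", trainable=False)
encoder_inputs = dict(
    input_word_ids=tf.zeros((2, 10), tf.int32),
    input_mask=tf.ones((2, 10), tf.int32),
    input_type_ids=tf.zeros((2, 10), tf.int32))
encoder_outputs = encoder(encoder_inputs)
# Per the linked API, outputs include "pooled_output" [batch, width] and
# "sequence_output" [batch, seq_length, width].
pooled = encoder_outputs["pooled_output"]
sequence = encoder_outputs["sequence_output"]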
...@@ -116,7 +116,13 @@ def create_v2_checkpoint(model, ...@@ -116,7 +116,13 @@ def create_v2_checkpoint(model,
"""Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint.""" """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
# Uses streaming-restore in eager mode to read V1 name-based checkpoints. # Uses streaming-restore in eager mode to read V1 name-based checkpoints.
model.load_weights(src_checkpoint).assert_existing_objects_matched() model.load_weights(src_checkpoint).assert_existing_objects_matched()
checkpoint = tf.train.Checkpoint(**{checkpoint_model_name: model}) if hasattr(model, "checkpoint_items"):
checkpoint_items = model.checkpoint_items
else:
checkpoint_items = {}
checkpoint_items[checkpoint_model_name] = model
checkpoint = tf.train.Checkpoint(**checkpoint_items)
checkpoint.save(output_path) checkpoint.save(output_path)
......
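With this change, a model may expose a `checkpoint_items` property to give sub-objects their own keys in the converted V2 checkpoint. A minimal sketch of such a model, assuming illustrative class and layer names:

import tensorflow as tf

class ToyPretrainer(tf.keras.Model):
  """Illustrative model exposing `checkpoint_items` for checkpointing."""

  def __init__(self):
    super().__init__()
    self.encoder = tf.keras.layers.Dense(4, name="encoder")
    self.masked_lm = tf.keras.layers.Dense(4, name="masked_lm")

  @property
  def checkpoint_items(self):
    # Sub-objects saved under their own names, alongside the full model.
    return dict(encoder=self.encoder, masked_lm=self.masked_lm)

  def call(self, inputs):
    return self.masked_lm(self.encoder(inputs))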
...@@ -16,3 +16,4 @@ ...@@ -16,3 +16,4 @@
"""Experiments definition.""" """Experiments definition."""
# pylint: disable=unused-import # pylint: disable=unused-import
from official.nlp.configs import finetuning_experiments from official.nlp.configs import finetuning_experiments
from official.nlp.configs import pretraining_experiments
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pretraining experiment configurations."""
# pylint: disable=g-doc-return-or-yield,line-too-long
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import pretrain_dataloader
from official.nlp.tasks import masked_lm
@exp_factory.register_config_factory('bert/pretraining')
def bert_pretraining() -> cfg.ExperimentConfig:
"""BERT pretraining experiment."""
config = cfg.ExperimentConfig(
task=masked_lm.MaskedLMConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig(
is_training=False)),
trainer=cfg.TrainerConfig(
train_steps=1000000,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay': [
'LayerNorm', 'layer_norm', 'bias'
],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 1e-4,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
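Once this module is imported (it is pulled in via the experiments package above), the factory can be looked up by its registered name. A short usage sketch; the override values are illustrative:

from official.core import exp_factory
from official.nlp.configs import pretraining_experiments  # pylint: disable=unused-import

config = exp_factory.get_exp_config('bert/pretraining')
# Override nested fields through the config's dict-based `override` method.
config.override({'trainer': {'train_steps': 10000}})
print(config.trainer.optimizer_config.optimizer.type)  # 'adamw'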