"megatron/vscode:/vscode.git/clone" did not exist on "7ac342b704d4cf6d5391bb6f6ef32cba51cc8972"
Commit 838339f6 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Update fine_tuning_bert.ipynb colab to use tf-models-official==2.4.0 and use...

Update fine_tuning_bert.ipynb colab to use tf-models-official==2.4.0 and use the new checkpoint and tfhub model.

PiperOrigin-RevId: 357219857
parent 2e398eca
......@@ -3,7 +3,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vXLA5InzXydn"
},
"source": [
......@@ -15,8 +14,6 @@
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "RuRlpLL-X0R_"
},
"outputs": [],
......@@ -37,7 +34,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1mLJmVotXs64"
},
"source": [
......@@ -47,7 +43,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hYEwGTeCXnnX"
},
"source": [
......@@ -73,7 +68,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YN2ACivEPxgD"
},
"source": [
......@@ -85,7 +79,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "s2d9S2CSSO1z"
},
"source": [
......@@ -95,7 +88,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fsACVQpVSifi"
},
"source": [
......@@ -110,19 +102,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "NvNr2svBM-p3"
},
"outputs": [],
"source": [
"!pip install -q tf-models-official==2.3.0"
"!pip install -q tf-models-official==2.4.0"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "U-7qPCjWUAyy"
},
"source": [
......@@ -133,8 +122,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lXsXev5MNr20"
},
"outputs": [],
......@@ -163,13 +150,12 @@
"import official.nlp.data.classifier_data_lib\n",
"import official.nlp.modeling.losses\n",
"import official.nlp.modeling.models\n",
"import official.nlp.modeling.networks"
"import official.nlp.modeling.networks\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mbanlzTvJBsz"
},
"source": [
......@@ -179,7 +165,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "PpW0x8TpR8DT"
},
"source": [
......@@ -190,20 +175,17 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "vzRHOLciR8eq"
},
"outputs": [],
"source": [
"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12\"\n",
"gs_folder_bert = \"gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12\"\n",
"tf.io.gfile.listdir(gs_folder_bert)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9uFskufsR2LT"
},
"source": [
......@@ -214,19 +196,16 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "e0dAkUttJAzj"
},
"outputs": [],
"source": [
"hub_url_bert = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\""
"hub_url_bert = \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3\""
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Qv6abtRvH4xO"
},
"source": [
......@@ -239,7 +218,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "28DvUhC1YUiB"
},
"source": [
......@@ -257,8 +235,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ijikx5OsH9AT"
},
"outputs": [],
......@@ -272,8 +248,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xf9zz4vLYXjr"
},
"outputs": [],
......@@ -284,7 +258,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ZgBg2r2nYT-K"
},
"source": [
......@@ -295,8 +268,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "IQrHxv7W7jH5"
},
"outputs": [],
......@@ -307,7 +278,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vhsVWYNxazz5"
},
"source": [
......@@ -318,8 +288,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "n0gfc_VTayfQ"
},
"outputs": [],
......@@ -330,7 +298,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "38zJcap6xkbC"
},
"source": [
......@@ -341,8 +308,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "xON_i6SkwApW"
},
"outputs": [],
......@@ -356,7 +321,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9fbTyfJpNr7x"
},
"source": [
......@@ -366,7 +330,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wqeN54S61ZKQ"
},
"source": [
......@@ -381,8 +344,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "idxyhmrCQcw5"
},
"outputs": [],
......@@ -398,7 +359,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "zYHDSquU2lDU"
},
"source": [
......@@ -409,8 +369,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "L_OfOYPg853R"
},
"outputs": [],
......@@ -424,7 +382,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kkAXLtuyWWDI"
},
"source": [
......@@ -438,7 +395,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "62UTWLQd9-LB"
},
"source": [
......@@ -451,8 +407,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "bdL-dRNRBRJT"
},
"outputs": [],
......@@ -463,7 +417,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "UrPktnqpwqie"
},
"source": [
......@@ -474,8 +427,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "BR7BmtU498Bh"
},
"outputs": [],
......@@ -495,8 +446,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "has42aUdfky-"
},
"outputs": [],
......@@ -508,7 +457,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "MU9lTWy_xXbb"
},
"source": [
......@@ -519,8 +467,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "USD8uihw-g4J"
},
"outputs": [],
......@@ -533,7 +479,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xmNv4l4k-dBZ"
},
"source": [
......@@ -543,7 +488,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "DIWjNIKq-ldh"
},
"source": [
......@@ -556,7 +500,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ulNZ4U96-8JZ"
},
"source": [
......@@ -567,8 +510,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "EezOO9qj91kP"
},
"outputs": [],
......@@ -581,7 +522,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rxLenwAvCkBf"
},
"source": [
......@@ -592,8 +532,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2CetH_5C9P2m"
},
"outputs": [],
......@@ -609,7 +547,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "P5UBnCn8Ii6s"
},
"source": [
......@@ -622,8 +559,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "sDGiWYPLEd5a"
},
"outputs": [],
......@@ -666,8 +601,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "yuLKxf6zHxw-"
},
"outputs": [],
......@@ -685,7 +618,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7FC5aLVxKVKK"
},
"source": [
......@@ -696,8 +628,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jyjTdGpFhO_1"
},
"outputs": [],
......@@ -711,7 +641,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "FSwymsbkbLDA"
},
"source": [
......@@ -721,7 +650,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Efrj3Cn1kLAp"
},
"source": [
......@@ -731,7 +659,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xxpOY5r2Ayq6"
},
"source": [
......@@ -742,8 +669,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ujapVfZ_AKW7"
},
"outputs": [],
......@@ -761,7 +686,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "96ldxDSwkVkj"
},
"source": [
......@@ -774,8 +698,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cH682__U0FBv"
},
"outputs": [],
......@@ -787,7 +709,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "XqKp3-5GIZlw"
},
"source": [
......@@ -798,8 +719,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "bAQblMIjwkvx"
},
"outputs": [],
......@@ -810,7 +729,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "sFmVG4SKZAw8"
},
"source": [
......@@ -821,8 +739,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "VTjgPbp4ZDKo"
},
"outputs": [],
......@@ -837,7 +753,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Q0NTdwZsQK8n"
},
"source": [
......@@ -850,8 +765,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "8L__-erBwLIQ"
},
"outputs": [],
......@@ -862,7 +775,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mKAvkQc3heSy"
},
"source": [
......@@ -875,21 +787,18 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "97Ll2Gichd_Y"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(model=bert_encoder)\n",
"checkpoint.restore(\n",
"checkpoint = tf.train.Checkpoint(encoder=bert_encoder)\n",
"checkpoint.read(\n",
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2oHOql35k3Dd"
},
"source": [
......@@ -899,7 +808,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "115caFLMk-_l"
},
"source": [
......@@ -913,8 +821,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "w8qXKRZuCwW4"
},
"outputs": [],
......@@ -937,7 +843,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "pXRGxiRNEHS2"
},
"source": [
......@@ -948,8 +853,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "eQNA16bhDpky"
},
"outputs": [],
......@@ -960,7 +863,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xqu_K71fJQB8"
},
"source": [
......@@ -970,7 +872,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "78FEUOOEkoP0"
},
"source": [
......@@ -980,7 +881,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OTNcA0O0nSq9"
},
"source": [
......@@ -991,8 +891,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "nzi8hjeTQTRs"
},
"outputs": [],
......@@ -1015,7 +913,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IFtKFWbNKb0u"
},
"source": [
......@@ -1028,8 +925,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9ZoUgDUNJPz3"
},
"outputs": [],
......@@ -1049,7 +944,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "7ynJibkBRTJF"
},
"source": [
......@@ -1060,8 +954,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "umo0ttrgRYIM"
},
"outputs": [],
......@@ -1076,8 +968,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "utGl0M3aZCE4"
},
"outputs": [],
......@@ -1088,7 +978,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "fVo_AnT0l26j"
},
"source": [
......@@ -1101,8 +990,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Nl5x6nElZqkP"
},
"outputs": [],
......@@ -1115,8 +1002,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"collapsed": true,
"id": "y_ACvKPsVUXC"
},
"outputs": [],
......@@ -1137,7 +1023,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "eQceYqRFT_Eg"
},
"source": [
......@@ -1147,7 +1032,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SaC1RlFawUpc"
},
"source": [
......@@ -1158,7 +1042,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "CwUdjFBkzUgh"
},
"source": [
......@@ -1170,7 +1053,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "2UTQrkyOT5wD"
},
"source": [
......@@ -1181,8 +1063,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XQeDFOzYR9Z9"
},
"outputs": [],
......@@ -1195,7 +1075,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "XrFQbfErUWxa"
},
"source": [
......@@ -1206,8 +1085,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ymw7GOHpSHKU"
},
"outputs": [],
......@@ -1234,7 +1111,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "uX_Sp-wTUoRm"
},
"source": [
......@@ -1245,8 +1121,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "rkHxIK57SQ_r"
},
"outputs": [],
......@@ -1267,7 +1141,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "stbaVouogvzS"
},
"source": [
......@@ -1278,8 +1151,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "gwhrlQl4gxVF"
},
"outputs": [],
......@@ -1290,7 +1161,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dbJ76vSJj77j"
},
"source": [
......@@ -1300,7 +1170,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9J95LFRohiYw"
},
"source": [
......@@ -1311,8 +1180,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "gCvaLLAxPuMc"
},
"outputs": [],
......@@ -1356,8 +1223,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "rutkBadrhzdR"
},
"outputs": [],
......@@ -1384,8 +1249,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "59TVgt4Z7fuU"
},
"outputs": [],
......@@ -1396,7 +1259,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "QbklKt-w_CiI"
},
"source": [
......@@ -1411,8 +1273,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "GDWrHm0BGpbX"
},
"outputs": [],
......@@ -1426,8 +1286,6 @@
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {},
"colab_type": "code",
"id": "Y29meH0qGq_5"
},
"outputs": [],
......@@ -1439,13 +1297,11 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "lo6479At4sP1"
},
"outputs": [],
"source": [
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}/2\",\n",
"hub_encoder = hub.KerasLayer(f\"https://tfhub.dev/tensorflow/{hub_model_name}/3\",\n",
" trainable=True)\n",
"\n",
"print(f\"The Hub encoder has {len(hub_encoder.trainable_variables)} trainable variables\")"
......@@ -1454,7 +1310,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "iTzF574wivQv"
},
"source": [
......@@ -1465,27 +1320,25 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "XEcYrCR45Uwo"
},
"outputs": [],
"source": [
"result = hub_encoder(\n",
" inputs=[glue_train['input_word_ids'][:10],\n",
" glue_train['input_mask'][:10],\n",
" glue_train['input_type_ids'][:10],],\n",
" inputs=dict(\n",
" input_word_ids=glue_train['input_word_ids'][:10],\n",
" input_mask=glue_train['input_mask'][:10],\n",
" input_type_ids=glue_train['input_type_ids'][:10],),\n",
" training=False,\n",
")\n",
"\n",
"print(\"Pooled output shape:\", result[0].shape)\n",
"print(\"Sequence output shape:\", result[1].shape)"
"print(\"Pooled output shape:\", result['pooled_output'].shape)\n",
"print(\"Sequence output shape:\", result['sequence_output'].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "cjojn8SmLSRI"
},
"source": [
......@@ -1498,33 +1351,31 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9nTDaApyLR70"
},
"outputs": [],
"source": [
"hub_classifier, hub_encoder = bert.bert_models.classifier_model(\n",
" # Caution: Most of `bert_config` is ignored if you pass a hub url.\n",
" bert_config=bert_config, hub_module_url=hub_url_bert, num_labels=2)"
"hub_classifier = nlp.modeling.models.BertClassifier(\n",
" bert_encoder,\n",
" num_classes=2,\n",
" dropout_rate=0.1,\n",
" initializer=tf.keras.initializers.TruncatedNormal(\n",
" stddev=0.02))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "xMJX3wV0_v7I"
},
"source": [
"The one downside to loading this model from TFHub is that the structure of internal keras layers is not restored. So it's more difficult to inspect or modify the model. The `TransformerEncoder` model is now a single layer:"
"The one downside to loading this model from TFHub is that the structure of internal keras layers is not restored. So it's more difficult to inspect or modify the model. The `BertEncoder` model is now a single layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "pD71dnvhM2QS"
},
"outputs": [],
......@@ -1536,8 +1387,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "nLZD-isBzNKi"
},
"outputs": [],
......@@ -1552,7 +1401,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "ZxSqH0dNAgXV"
},
"source": [
......@@ -1560,13 +1408,12 @@
"\n",
"### Low level model building\n",
"\n",
"If you need a more control over the construction of the model it's worth noting that the `classifier_model` function used earlier is really just a thin wrapper over the `nlp.modeling.networks.TransformerEncoder` and `nlp.modeling.models.BertClassifier` classes. Just remember that if you start modifying the architecture it may not be correct or possible to reload the pre-trained checkpoint so you'll need to retrain from scratch."
"If you need a more control over the construction of the model it's worth noting that the `classifier_model` function used earlier is really just a thin wrapper over the `nlp.modeling.networks.BertEncoder` and `nlp.modeling.models.BertClassifier` classes. Just remember that if you start modifying the architecture it may not be correct or possible to reload the pre-trained checkpoint so you'll need to retrain from scratch."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "0cgABEwDj06P"
},
"source": [
......@@ -1577,43 +1424,38 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5r_yqhBFSVEM"
},
"outputs": [],
"source": [
"transformer_config = config_dict.copy()\n",
"bert_encoder_config = config_dict.copy()\n",
"\n",
"# You need to rename a few fields to make this work:\n",
"transformer_config['attention_dropout_rate'] = transformer_config.pop('attention_probs_dropout_prob')\n",
"transformer_config['activation'] = tf_utils.get_activation(transformer_config.pop('hidden_act'))\n",
"transformer_config['dropout_rate'] = transformer_config.pop('hidden_dropout_prob')\n",
"transformer_config['initializer'] = tf.keras.initializers.TruncatedNormal(\n",
" stddev=transformer_config.pop('initializer_range'))\n",
"transformer_config['max_sequence_length'] = transformer_config.pop('max_position_embeddings')\n",
"transformer_config['num_layers'] = transformer_config.pop('num_hidden_layers')\n",
"bert_encoder_config['attention_dropout_rate'] = bert_encoder_config.pop('attention_probs_dropout_prob')\n",
"bert_encoder_config['activation'] = tf_utils.get_activation(bert_encoder_config.pop('hidden_act'))\n",
"bert_encoder_config['dropout_rate'] = bert_encoder_config.pop('hidden_dropout_prob')\n",
"bert_encoder_config['initializer'] = tf.keras.initializers.TruncatedNormal(\n",
" stddev=bert_encoder_config.pop('initializer_range'))\n",
"bert_encoder_config['max_sequence_length'] = bert_encoder_config.pop('max_position_embeddings')\n",
"bert_encoder_config['num_layers'] = bert_encoder_config.pop('num_hidden_layers')\n",
"\n",
"transformer_config"
"bert_encoder_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "rIO8MI7LLijh"
},
"outputs": [],
"source": [
"manual_encoder = nlp.modeling.networks.TransformerEncoder(**transformer_config)"
"manual_encoder = nlp.modeling.networks.BertEncoder(**bert_encoder_config)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4a4tFSg9krRi"
},
"source": [
......@@ -1624,21 +1466,18 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "X6N9NEqfXJCx"
},
"outputs": [],
"source": [
"checkpoint = tf.train.Checkpoint(model=manual_encoder)\n",
"checkpoint.restore(\n",
"checkpoint = tf.train.Checkpoint(encoder=manual_encoder)\n",
"checkpoint.read(\n",
" os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1BPiPO4ykuwM"
},
"source": [
......@@ -1649,8 +1488,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "hlVdgJKmj389"
},
"outputs": [],
......@@ -1664,7 +1501,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "nJMXvVgJkyBv"
},
"source": [
......@@ -1675,8 +1511,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "tQX57GJ6wkAb"
},
"outputs": [],
......@@ -1684,17 +1518,14 @@
"manual_classifier = nlp.modeling.models.BertClassifier(\n",
" bert_encoder,\n",
" num_classes=2,\n",
" dropout_rate=transformer_config['dropout_rate'],\n",
" initializer=tf.keras.initializers.TruncatedNormal(\n",
" stddev=bert_config.initializer_range))"
" dropout_rate=bert_encoder_config['dropout_rate'],\n",
" initializer=bert_encoder_config['initializer'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "kB-nBWhQk0dS"
},
"outputs": [],
......@@ -1705,7 +1536,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "E6AJlOSyIO1L"
},
"source": [
......@@ -1720,8 +1550,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "28Dv3BPRlFTD"
},
"outputs": [],
......@@ -1733,7 +1561,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "LRjcHr0UlT8c"
},
"source": [
......@@ -1746,8 +1573,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "MHY8K6kDngQn"
},
"outputs": [],
......@@ -1765,8 +1590,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"collapsed": true,
"id": "wKIcSprulu3P"
},
"outputs": [],
......@@ -1782,7 +1606,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "IMTC_gfAl_PZ"
},
"source": [
......@@ -1793,8 +1616,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "YRt3VTmBmCBY"
},
"outputs": [],
......@@ -1816,7 +1637,6 @@
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "l8D9Lv3Bn740"
},
"source": [
......@@ -1827,8 +1647,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "2Hf2rpRXk89N"
},
"outputs": [],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment