Comparing-TF-and-PT-models.ipynb 61.2 KB
Newer Older
1
2
3
4
5
6
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
VictorSanh's avatar
VictorSanh committed
7
8
    "# Comparing TensorFlow (original) and PyTorch models\n",
    "\n",
9
    "You can use this small notebook to check the conversion of the model's weights from the TensorFlow model to the PyTorch model. In the following, we compare the weights of the last layer on a simple example (in `input.txt`), but both models return all the hidden layers, so you can check every stage of the model.\n",
VictorSanh's avatar
VictorSanh committed
10
    "\n",
11
12
13
14
15
    "To run this notebook, follow these instructions:\n",
    "- make sure that your Python environment has both TensorFlow and PyTorch installed,\n",
    "- download the original TensorFlow implementation,\n",
    "- download a pre-trained TensorFlow model as indicated in the TensorFlow implementation readme,\n",
    "- run the script `convert_tf_checkpoint_to_pytorch.py` as indicated in the `README` to convert the pre-trained TensorFlow model to PyTorch.\n",
VictorSanh's avatar
VictorSanh committed
16
    "\n",
17
    "If needed, change the relative paths indicated in this notebook (at the beginning of Sections 1 and 2) to point to the relevant models and code."
VictorSanh's avatar
VictorSanh committed
18
19
   ]
  },
thomwolf's avatar
thomwolf committed
20
21
22
23
24
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
25
26
     "end_time": "2018-11-15T14:56:48.412622Z",
     "start_time": "2018-11-15T14:56:48.400110Z"
thomwolf's avatar
thomwolf committed
27
28
29
30
31
32
33
34
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../')"
   ]
  },
VictorSanh's avatar
VictorSanh committed
35
36
37
38
39
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1/ TensorFlow code"
40
41
42
43
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
44
   "execution_count": 2,
45
46
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
47
48
     "end_time": "2018-11-15T14:56:49.483829Z",
     "start_time": "2018-11-15T14:56:49.471296Z"
49
50
    }
   },
VictorSanh's avatar
VictorSanh committed
51
52
   "outputs": [],
   "source": [
53
54
    "original_tf_inplem_dir = \"./tensorflow_code/\"\n",
    "model_dir = \"../google_models/uncased_L-12_H-768_A-12/\"\n",
VictorSanh's avatar
VictorSanh committed
55
56
57
58
59
    "\n",
    "vocab_file = model_dir + \"vocab.txt\"\n",
    "bert_config_file = model_dir + \"bert_config.json\"\n",
    "init_checkpoint = model_dir + \"bert_model.ckpt\"\n",
    "\n",
60
    "input_file = \"./samples/input.txt\"\n",
VictorSanh's avatar
VictorSanh committed
61
62
63
64
65
    "max_seq_length = 128"
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
66
   "execution_count": 6,
67
68
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
69
70
     "end_time": "2018-11-15T14:57:51.597932Z",
     "start_time": "2018-11-15T14:57:51.549466Z"
71
72
    }
   },
thomwolf's avatar
thomwolf committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
   "outputs": [
    {
     "ename": "DuplicateFlagError",
     "evalue": "The flag 'input_file' is defined twice. First from *, Second from *.  Description from first occurrence: (no help available)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mDuplicateFlagError\u001b[0m                        Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-86ecffb49060>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec_from_file_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moriginal_tf_inplem_dir\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'/extract_features_tensorflow.py'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_from_spec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexec_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'extract_features_tensorflow'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
      "\u001b[0;32m~/Documents/Thomas/Code/HF/BERT/pytorch-pretrained-BERT/tensorflow_code/extract_features_tensorflow.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0mFLAGS\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFLAGS\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDEFINE_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"input_file\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     35\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDEFINE_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"output_file\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/tensorflow/python/platform/flags.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     56\u001b[0m           \u001b[0;34m'Use of the keyword argument names (flag_name, default_value, '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m           'docstring) is deprecated, please use (name, default, help) instead.')\n\u001b[0;32m---> 58\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0moriginal_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m   \u001b[0;32mreturn\u001b[0m \u001b[0mtf_decorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_decorator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moriginal_function\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE_string\u001b[0;34m(name, default, help, flag_values, **args)\u001b[0m\n\u001b[1;32m    239\u001b[0m   \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_argument_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArgumentParser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    240\u001b[0m   \u001b[0mserializer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_argument_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArgumentSerializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 241\u001b[0;31m   \u001b[0mDEFINE\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparser\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflag_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mserializer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    243\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE\u001b[0;34m(parser, name, default, help, flag_values, serializer, module_name, **args)\u001b[0m\n\u001b[1;32m     80\u001b[0m   \"\"\"\n\u001b[1;32m     81\u001b[0m   DEFINE_flag(_flag.Flag(parser, serializer, name, default, help, **args),\n\u001b[0;32m---> 82\u001b[0;31m               flag_values, module_name)\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE_flag\u001b[0;34m(flag, flag_values, module_name)\u001b[0m\n\u001b[1;32m    102\u001b[0m   \u001b[0;31m# Copying the reference to flag_values prevents pychecker warnings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    103\u001b[0m   \u001b[0mfv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag_values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 104\u001b[0;31m   \u001b[0mfv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mflag\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    105\u001b[0m   \u001b[0;31m# Tell flag_values who's defining the flag.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    106\u001b[0m   \u001b[0;32mif\u001b[0m \u001b[0mmodule_name\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_flagvalues.py\u001b[0m in \u001b[0;36m__setitem__\u001b[0;34m(self, name, flag)\u001b[0m\n\u001b[1;32m    427\u001b[0m         \u001b[0;31m# module is simply being imported a subsequent time.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    428\u001b[0m         \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 429\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0m_exceptions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDuplicateFlagError\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_flag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    430\u001b[0m     \u001b[0mshort_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshort_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    431\u001b[0m     \u001b[0;31m# If a new flag overrides an old one, we need to cleanup the old flag's\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDuplicateFlagError\u001b[0m: The flag 'input_file' is defined twice. First from *, Second from *.  Description from first occurrence: (no help available)"
     ]
    }
   ],
94
   "source": [
95
    "import importlib.util\n",
VictorSanh's avatar
VictorSanh committed
96
97
    "import sys\n",
    "\n",
thomwolf's avatar
thomwolf committed
98
    "spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/extract_features_tensorflow.py')\n",
99
100
101
102
103
    "module = importlib.util.module_from_spec(spec)\n",
    "spec.loader.exec_module(module)\n",
    "sys.modules['extract_features_tensorflow'] = module\n",
    "\n",
    "from extract_features_tensorflow import *"
104
105
106
107
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
108
   "execution_count": 8,
109
110
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
111
112
     "end_time": "2018-11-15T14:58:05.650987Z",
     "start_time": "2018-11-15T14:58:05.541620Z"
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Example ***\n",
      "INFO:tensorflow:unique_id: 0\n",
      "INFO:tensorflow:tokens: [CLS] who was jim henson ? [SEP] jim henson was a puppet ##eer [SEP]\n",
      "INFO:tensorflow:input_ids: 101 2040 2001 3958 27227 1029 102 3958 27227 2001 1037 13997 11510 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
      "INFO:tensorflow:input_type_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n"
     ]
    }
   ],
   "source": [
thomwolf's avatar
thomwolf committed
130
    "layer_indexes = list(range(12))\n",
131
132
133
134
135
136
137
138
139
140
141
142
143
144
    "bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "    vocab_file=vocab_file, do_lower_case=True)\n",
    "examples = read_examples(input_file)\n",
    "\n",
    "features = convert_examples_to_features(\n",
    "    examples=examples, seq_length=max_seq_length, tokenizer=tokenizer)\n",
    "unique_id_to_feature = {}\n",
    "for feature in features:\n",
    "    unique_id_to_feature[feature.unique_id] = feature"
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
145
   "execution_count": 9,
146
147
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
148
149
     "end_time": "2018-11-15T14:58:11.562443Z",
     "start_time": "2018-11-15T14:58:08.036485Z"
150
151
152
153
154
155
156
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
thomwolf's avatar
thomwolf committed
157
158
159
      "WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x11ea7f1e0>) includes params argument, but params are not passed to Estimator.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9\n",
      "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
160
161
162
163
164
      "graph_options {\n",
      "  rewrite_options {\n",
      "    meta_optimizer_iterations: ONE\n",
      "  }\n",
      "}\n",
thomwolf's avatar
thomwolf committed
165
      ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x121b163c8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
      "WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n",
      "INFO:tensorflow:_TPUContext: eval_on_tpu True\n",
      "WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n"
     ]
    }
   ],
   "source": [
    "is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2\n",
    "run_config = tf.contrib.tpu.RunConfig(\n",
    "    master=None,\n",
    "    tpu_config=tf.contrib.tpu.TPUConfig(\n",
    "        num_shards=1,\n",
    "        per_host_input_for_training=is_per_host))\n",
    "\n",
    "model_fn = model_fn_builder(\n",
    "    bert_config=bert_config,\n",
    "    init_checkpoint=init_checkpoint,\n",
    "    layer_indexes=layer_indexes,\n",
    "    use_tpu=False,\n",
    "    use_one_hot_embeddings=False)\n",
    "\n",
    "# If TPU is not available, this will fall back to normal Estimator on CPU\n",
    "# or GPU.\n",
    "estimator = tf.contrib.tpu.TPUEstimator(\n",
    "    use_tpu=False,\n",
    "    model_fn=model_fn,\n",
    "    config=run_config,\n",
    "    predict_batch_size=1)\n",
    "\n",
    "input_fn = input_fn_builder(\n",
    "    features=features, seq_length=max_seq_length)"
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
201
   "execution_count": 10,
202
203
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
204
205
     "end_time": "2018-11-15T14:58:21.736543Z",
     "start_time": "2018-11-15T14:58:16.723829Z"
206
207
208
209
210
211
212
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
thomwolf's avatar
thomwolf committed
213
      "INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9, running initialization to predict.\n",
214
215
216
217
218
219
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Running infer on CPU\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
thomwolf's avatar
thomwolf committed
220
221
222
223
224
225
226
227
228
229
230
231
      "extracting layer 0\n",
      "extracting layer 1\n",
      "extracting layer 2\n",
      "extracting layer 3\n",
      "extracting layer 4\n",
      "extracting layer 5\n",
      "extracting layer 6\n",
      "extracting layer 7\n",
      "extracting layer 8\n",
      "extracting layer 9\n",
      "extracting layer 10\n",
      "extracting layer 11\n",
232
233
234
235
236
237
      "INFO:tensorflow:prediction_loop marked as finished\n",
      "INFO:tensorflow:prediction_loop marked as finished\n"
     ]
    }
   ],
   "source": [
thomwolf's avatar
thomwolf committed
238
    "tensorflow_all_out = []\n",
239
240
241
242
243
    "for result in estimator.predict(input_fn, yield_single_examples=True):\n",
    "    unique_id = int(result[\"unique_id\"])\n",
    "    feature = unique_id_to_feature[unique_id]\n",
    "    output_json = collections.OrderedDict()\n",
    "    output_json[\"linex_index\"] = unique_id\n",
thomwolf's avatar
thomwolf committed
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
    "    tensorflow_all_out_features = []\n",
    "    # for (i, token) in enumerate(feature.tokens):\n",
    "    all_layers = []\n",
    "    for (j, layer_index) in enumerate(layer_indexes):\n",
    "        print(\"extracting layer {}\".format(j))\n",
    "        layer_output = result[\"layer_output_%d\" % j]\n",
    "        layers = collections.OrderedDict()\n",
    "        layers[\"index\"] = layer_index\n",
    "        layers[\"values\"] = layer_output\n",
    "        all_layers.append(layers)\n",
    "    tensorflow_out_features = collections.OrderedDict()\n",
    "    tensorflow_out_features[\"layers\"] = all_layers\n",
    "    tensorflow_all_out_features.append(tensorflow_out_features)\n",
    "\n",
    "    output_json[\"features\"] = tensorflow_all_out_features\n",
    "    tensorflow_all_out.append(output_json)"
260
261
262
263
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
264
   "execution_count": 11,
thomwolf's avatar
thomwolf committed
265
266
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
267
268
     "end_time": "2018-11-15T14:58:23.970714Z",
     "start_time": "2018-11-15T14:58:23.931930Z"
thomwolf's avatar
thomwolf committed
269
270
271
272
273
274
275
276
277
278
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "odict_keys(['linex_index', 'features'])\n",
thomwolf's avatar
thomwolf committed
279
280
      "number of tokens 1\n",
      "number of layers 12\n"
thomwolf's avatar
thomwolf committed
281
     ]
thomwolf's avatar
thomwolf committed
282
283
284
285
286
287
288
    },
    {
     "data": {
      "text/plain": [
       "(128, 768)"
      ]
     },
thomwolf's avatar
thomwolf committed
289
     "execution_count": 11,
thomwolf's avatar
thomwolf committed
290
291
     "metadata": {},
     "output_type": "execute_result"
thomwolf's avatar
thomwolf committed
292
293
294
    }
   ],
   "source": [
thomwolf's avatar
thomwolf committed
295
296
297
298
299
300
    "print(len(tensorflow_all_out))\n",
    "print(len(tensorflow_all_out[0]))\n",
    "print(tensorflow_all_out[0].keys())\n",
    "print(\"number of tokens\", len(tensorflow_all_out[0]['features']))\n",
    "print(\"number of layers\", len(tensorflow_all_out[0]['features'][0]['layers']))\n",
    "tensorflow_all_out[0]['features'][0]['layers'][0]['values'].shape"
thomwolf's avatar
thomwolf committed
301
302
303
304
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
305
   "execution_count": 12,
306
307
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
308
309
     "end_time": "2018-11-15T14:58:25.547012Z",
     "start_time": "2018-11-15T14:58:25.516076Z"
310
311
    }
   },
thomwolf's avatar
thomwolf committed
312
   "outputs": [],
313
   "source": [
thomwolf's avatar
thomwolf committed
314
    "tensorflow_outputs = list(tensorflow_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)"
thomwolf's avatar
thomwolf committed
315
316
317
318
319
320
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
VictorSanh's avatar
VictorSanh committed
321
    "## 2/ PyTorch code"
322
323
324
325
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
326
327
328
329
330
331
332
333
334
335
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.chdir('./examples')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
336
337
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
338
339
     "end_time": "2018-11-15T15:03:49.528679Z",
     "start_time": "2018-11-15T15:03:49.497697Z"
340
341
342
343
    }
   },
   "outputs": [],
   "source": [
344
    "import extract_features\n",
thomwolf's avatar
thomwolf committed
345
    "import pytorch_transformers as ppb\n",
346
    "from extract_features import *"
347
348
349
   ]
  },
  {
thomwolf's avatar
thomwolf committed
350
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
351
   "execution_count": 25,
352
353
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
354
355
     "end_time": "2018-11-15T15:21:18.001177Z",
     "start_time": "2018-11-15T15:21:17.970369Z"
356
357
    }
   },
VictorSanh's avatar
VictorSanh committed
358
359
   "outputs": [],
   "source": [
thomwolf's avatar
thomwolf committed
360
    "init_checkpoint_pt = \"../../google_models/uncased_L-12_H-768_A-12/\""
VictorSanh's avatar
VictorSanh committed
361
362
363
364
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
365
   "execution_count": 26,
thomwolf's avatar
thomwolf committed
366
367
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
368
369
     "end_time": "2018-11-15T15:21:20.893669Z",
     "start_time": "2018-11-15T15:21:18.786623Z"
VictorSanh's avatar
VictorSanh committed
370
371
    },
    "scrolled": true
thomwolf's avatar
thomwolf committed
372
373
   },
   "outputs": [
thomwolf's avatar
thomwolf committed
374
375
376
377
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
thomwolf's avatar
thomwolf committed
378
379
      "11/15/2018 16:21:18 - INFO - pytorch_transformers.modeling_bert -   loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n",
      "11/15/2018 16:21:18 - INFO - pytorch_transformers.modeling_bert -   Model config {\n",
thomwolf's avatar
thomwolf committed
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n"
     ]
    },
thomwolf's avatar
thomwolf committed
395
396
397
398
    {
     "data": {
      "text/plain": [
       "BertModel(\n",
thomwolf's avatar
thomwolf committed
399
       "  (embeddings): BertEmbeddings(\n",
thomwolf's avatar
thomwolf committed
400
401
402
       "    (word_embeddings): Embedding(30522, 768)\n",
       "    (position_embeddings): Embedding(512, 768)\n",
       "    (token_type_embeddings): Embedding(2, 768)\n",
thomwolf's avatar
thomwolf committed
403
       "    (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
404
405
       "    (dropout): Dropout(p=0.1)\n",
       "  )\n",
thomwolf's avatar
thomwolf committed
406
       "  (encoder): BertEncoder(\n",
thomwolf's avatar
thomwolf committed
407
       "    (layer): ModuleList(\n",
thomwolf's avatar
thomwolf committed
408
409
410
       "      (0): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
411
412
413
414
415
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
416
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
417
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
418
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
419
420
421
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
422
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
423
424
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
425
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
426
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
427
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
428
429
430
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
431
432
433
       "      (1): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
434
435
436
437
438
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
439
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
440
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
441
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
442
443
444
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
445
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
446
447
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
448
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
449
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
450
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
451
452
453
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
454
455
456
       "      (2): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
457
458
459
460
461
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
462
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
463
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
464
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
465
466
467
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
468
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
469
470
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
471
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
472
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
473
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
474
475
476
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
477
478
479
       "      (3): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
480
481
482
483
484
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
485
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
486
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
487
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
488
489
490
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
491
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
492
493
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
494
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
495
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
496
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
497
498
499
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
500
501
502
       "      (4): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
503
504
505
506
507
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
508
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
509
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
510
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
511
512
513
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
514
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
515
516
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
517
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
518
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
519
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
520
521
522
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
523
524
525
       "      (5): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
526
527
528
529
530
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
531
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
532
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
533
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
534
535
536
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
537
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
538
539
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
540
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
541
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
542
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
543
544
545
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
546
547
548
       "      (6): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
549
550
551
552
553
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
554
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
555
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
556
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
557
558
559
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
560
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
561
562
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
563
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
564
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
565
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
566
567
568
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
569
570
571
       "      (7): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
572
573
574
575
576
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
577
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
578
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
579
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
580
581
582
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
583
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
584
585
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
586
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
587
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
588
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
589
590
591
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
592
593
594
       "      (8): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
595
596
597
598
599
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
600
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
601
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
602
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
603
604
605
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
606
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
607
608
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
609
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
610
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
611
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
612
613
614
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
615
616
617
       "      (9): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
618
619
620
621
622
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
623
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
624
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
625
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
626
627
628
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
629
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
630
631
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
632
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
633
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
634
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
635
636
637
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
638
639
640
       "      (10): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
641
642
643
644
645
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
646
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
647
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
648
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
649
650
651
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
652
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
653
654
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
655
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
656
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
657
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
658
659
660
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
661
662
663
       "      (11): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
664
665
666
667
668
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
669
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
670
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
671
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
672
673
674
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
675
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
676
677
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
678
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
679
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
680
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
681
682
683
684
685
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
thomwolf's avatar
thomwolf committed
686
       "  (pooler): BertPooler(\n",
thomwolf's avatar
thomwolf committed
687
688
689
690
691
692
       "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (activation): Tanh()\n",
       "  )\n",
       ")"
      ]
     },
thomwolf's avatar
thomwolf committed
693
     "execution_count": 26,
thomwolf's avatar
thomwolf committed
694
695
696
697
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
698
   "source": [
thomwolf's avatar
thomwolf committed
699
    "device = torch.device(\"cpu\")\n",
thomwolf's avatar
thomwolf committed
700
    "model = ppb.BertModel.from_pretrained(init_checkpoint_pt)\n",
thomwolf's avatar
thomwolf committed
701
702
703
704
705
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
706
   "execution_count": 27,
thomwolf's avatar
thomwolf committed
707
708
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
709
710
     "end_time": "2018-11-15T15:21:26.963427Z",
     "start_time": "2018-11-15T15:21:26.922494Z"
thomwolf's avatar
thomwolf committed
711
712
713
714
715
716
717
718
    },
    "code_folding": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertModel(\n",
thomwolf's avatar
thomwolf committed
719
       "  (embeddings): BertEmbeddings(\n",
thomwolf's avatar
thomwolf committed
720
721
722
       "    (word_embeddings): Embedding(30522, 768)\n",
       "    (position_embeddings): Embedding(512, 768)\n",
       "    (token_type_embeddings): Embedding(2, 768)\n",
thomwolf's avatar
thomwolf committed
723
       "    (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
724
725
       "    (dropout): Dropout(p=0.1)\n",
       "  )\n",
thomwolf's avatar
thomwolf committed
726
       "  (encoder): BertEncoder(\n",
thomwolf's avatar
thomwolf committed
727
       "    (layer): ModuleList(\n",
thomwolf's avatar
thomwolf committed
728
729
730
       "      (0): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
731
732
733
734
735
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
736
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
737
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
738
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
739
740
741
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
742
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
743
744
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
745
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
746
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
747
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
748
749
750
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
751
752
753
       "      (1): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
754
755
756
757
758
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
759
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
760
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
761
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
762
763
764
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
765
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
766
767
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
768
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
769
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
770
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
771
772
773
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
774
775
776
       "      (2): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
777
778
779
780
781
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
782
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
783
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
784
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
785
786
787
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
788
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
789
790
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
791
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
792
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
793
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
794
795
796
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
797
798
799
       "      (3): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
800
801
802
803
804
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
805
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
806
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
807
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
808
809
810
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
811
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
812
813
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
814
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
815
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
816
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
817
818
819
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
820
821
822
       "      (4): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
823
824
825
826
827
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
828
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
829
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
830
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
831
832
833
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
834
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
835
836
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
837
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
838
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
839
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
840
841
842
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
843
844
845
       "      (5): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
846
847
848
849
850
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
851
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
852
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
853
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
854
855
856
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
857
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
858
859
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
860
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
861
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
862
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
863
864
865
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
866
867
868
       "      (6): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
869
870
871
872
873
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
874
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
875
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
876
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
877
878
879
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
880
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
881
882
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
883
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
884
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
885
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
886
887
888
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
889
890
891
       "      (7): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
892
893
894
895
896
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
897
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
898
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
899
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
900
901
902
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
903
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
904
905
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
906
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
907
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
908
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
909
910
911
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
912
913
914
       "      (8): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
915
916
917
918
919
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
920
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
921
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
922
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
923
924
925
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
926
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
927
928
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
929
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
930
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
931
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
932
933
934
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
935
936
937
       "      (9): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
938
939
940
941
942
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
943
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
944
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
945
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
946
947
948
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
949
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
950
951
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
952
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
953
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
954
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
955
956
957
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
958
959
960
       "      (10): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
961
962
963
964
965
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
966
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
967
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
968
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
969
970
971
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
972
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
973
974
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
975
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
976
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
977
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
978
979
980
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
thomwolf's avatar
thomwolf committed
981
982
983
       "      (11): BertLayer(\n",
       "        (attention): BertAttention(\n",
       "          (self): BertSelfAttention(\n",
thomwolf's avatar
thomwolf committed
984
985
986
987
988
       "            (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
thomwolf's avatar
thomwolf committed
989
       "          (output): BertSelfOutput(\n",
thomwolf's avatar
thomwolf committed
990
       "            (dense): Linear(in_features=768, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
991
       "            (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
992
993
994
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
995
       "        (intermediate): BertIntermediate(\n",
thomwolf's avatar
thomwolf committed
996
997
       "          (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "        )\n",
thomwolf's avatar
thomwolf committed
998
       "        (output): BertOutput(\n",
thomwolf's avatar
thomwolf committed
999
       "          (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
thomwolf's avatar
thomwolf committed
1000
       "          (LayerNorm): BertLayerNorm()\n",
thomwolf's avatar
thomwolf committed
1001
1002
1003
1004
1005
       "          (dropout): Dropout(p=0.1)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "  )\n",
thomwolf's avatar
thomwolf committed
1006
       "  (pooler): BertPooler(\n",
thomwolf's avatar
thomwolf committed
1007
1008
1009
1010
1011
1012
       "    (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "    (activation): Tanh()\n",
       "  )\n",
       ")"
      ]
     },
thomwolf's avatar
thomwolf committed
1013
     "execution_count": 27,
thomwolf's avatar
thomwolf committed
1014
1015
1016
1017
1018
1019
1020
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n",
    "all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n",
thomwolf's avatar
thomwolf committed
1021
    "all_input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long)\n",
thomwolf's avatar
thomwolf committed
1022
1023
    "all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)\n",
    "\n",
thomwolf's avatar
thomwolf committed
1024
    "eval_data = TensorDataset(all_input_ids, all_input_mask, all_input_type_ids, all_example_index)\n",
thomwolf's avatar
thomwolf committed
1025
1026
1027
1028
    "eval_sampler = SequentialSampler(eval_data)\n",
    "eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)\n",
    "\n",
    "model.eval()"
1029
1030
1031
1032
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
1033
   "execution_count": 28,
1034
1035
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
1036
1037
     "end_time": "2018-11-15T15:21:30.718724Z",
     "start_time": "2018-11-15T15:21:30.329205Z"
1038
1039
    }
   },
thomwolf's avatar
thomwolf committed
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  101,  2040,  2001,  3958, 27227,  1029,   102,  3958, 27227,  2001,\n",
      "          1037, 13997, 11510,   102,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,\n",
      "             0,     0,     0,     0,     0,     0,     0,     0]])\n",
      "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 0, 0, 0, 0, 0, 0]])\n",
      "tensor([0])\n",
      "layer 0 0\n",
      "layer 1 1\n",
      "layer 2 2\n",
      "layer 3 3\n",
      "layer 4 4\n",
      "layer 5 5\n",
      "layer 6 6\n",
      "layer 7 7\n",
      "layer 8 8\n",
      "layer 9 9\n",
      "layer 10 10\n",
      "layer 11 11\n"
     ]
    }
   ],
1080
   "source": [
thomwolf's avatar
thomwolf committed
1081
1082
    "layer_indexes = list(range(12))\n",
    "\n",
thomwolf's avatar
thomwolf committed
1083
    "pytorch_all_out = []\n",
thomwolf's avatar
thomwolf committed
1084
1085
1086
1087
    "for input_ids, input_mask, input_type_ids, example_indices in eval_dataloader:\n",
    "    print(input_ids)\n",
    "    print(input_mask)\n",
    "    print(example_indices)\n",
thomwolf's avatar
thomwolf committed
1088
    "    input_ids = input_ids.to(device)\n",
thomwolf's avatar
thomwolf committed
1089
    "    input_mask = input_mask.to(device)\n",
thomwolf's avatar
thomwolf committed
1090
    "\n",
thomwolf's avatar
thomwolf committed
1091
    "    all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n",
thomwolf's avatar
thomwolf committed
1092
    "\n",
thomwolf's avatar
thomwolf committed
1093
    "    for b, example_index in enumerate(example_indices):\n",
thomwolf's avatar
thomwolf committed
1094
1095
1096
1097
1098
1099
    "        feature = features[example_index.item()]\n",
    "        unique_id = int(feature.unique_id)\n",
    "        # feature = unique_id_to_feature[unique_id]\n",
    "        output_json = collections.OrderedDict()\n",
    "        output_json[\"linex_index\"] = unique_id\n",
    "        all_out_features = []\n",
thomwolf's avatar
thomwolf committed
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
    "        # for (i, token) in enumerate(feature.tokens):\n",
    "        all_layers = []\n",
    "        for (j, layer_index) in enumerate(layer_indexes):\n",
    "            print(\"layer\", j, layer_index)\n",
    "            layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()\n",
    "            layer_output = layer_output[b]\n",
    "            layers = collections.OrderedDict()\n",
    "            layers[\"index\"] = layer_index\n",
    "            layer_output = layer_output\n",
    "            layers[\"values\"] = layer_output if not isinstance(layer_output, (int, float)) else [layer_output]\n",
    "            all_layers.append(layers)\n",
    "\n",
thomwolf's avatar
thomwolf committed
1112
1113
1114
1115
1116
1117
1118
1119
1120
    "            out_features = collections.OrderedDict()\n",
    "            out_features[\"layers\"] = all_layers\n",
    "            all_out_features.append(out_features)\n",
    "        output_json[\"features\"] = all_out_features\n",
    "        pytorch_all_out.append(output_json)"
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
1121
   "execution_count": 29,
thomwolf's avatar
thomwolf committed
1122
1123
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
1124
1125
     "end_time": "2018-11-15T15:21:35.703615Z",
     "start_time": "2018-11-15T15:21:35.666150Z"
thomwolf's avatar
thomwolf committed
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "odict_keys(['linex_index', 'features'])\n",
thomwolf's avatar
thomwolf committed
1136
1137
1138
      "number of tokens 1\n",
      "number of layers 12\n",
      "hidden_size 128\n"
thomwolf's avatar
thomwolf committed
1139
     ]
thomwolf's avatar
thomwolf committed
1140
1141
1142
1143
1144
1145
1146
    },
    {
     "data": {
      "text/plain": [
       "(128, 768)"
      ]
     },
thomwolf's avatar
thomwolf committed
1147
     "execution_count": 29,
thomwolf's avatar
thomwolf committed
1148
1149
     "metadata": {},
     "output_type": "execute_result"
thomwolf's avatar
thomwolf committed
1150
1151
1152
1153
1154
1155
    }
   ],
   "source": [
    "print(len(pytorch_all_out))\n",
    "print(len(pytorch_all_out[0]))\n",
    "print(pytorch_all_out[0].keys())\n",
thomwolf's avatar
thomwolf committed
1156
1157
1158
1159
    "print(\"number of tokens\", len(pytorch_all_out))\n",
    "print(\"number of layers\", len(pytorch_all_out[0]['features'][0]['layers']))\n",
    "print(\"hidden_size\", len(pytorch_all_out[0]['features'][0]['layers'][0]['values']))\n",
    "pytorch_all_out[0]['features'][0]['layers'][0]['values'].shape"
thomwolf's avatar
thomwolf committed
1160
1161
1162
1163
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
1164
   "execution_count": 30,
thomwolf's avatar
thomwolf committed
1165
1166
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
1167
1168
     "end_time": "2018-11-15T15:21:36.999073Z",
     "start_time": "2018-11-15T15:21:36.966762Z"
thomwolf's avatar
thomwolf committed
1169
1170
1171
1172
    }
   },
   "outputs": [
    {
thomwolf's avatar
thomwolf committed
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 768)\n",
      "(128, 768)\n"
     ]
    }
   ],
   "source": [
    "pytorch_outputs = list(pytorch_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)\n",
    "print(pytorch_outputs[0].shape)\n",
    "print(pytorch_outputs[1].shape)"
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
1189
   "execution_count": 31,
thomwolf's avatar
thomwolf committed
1190
1191
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
1192
1193
     "end_time": "2018-11-15T15:21:37.936522Z",
     "start_time": "2018-11-15T15:21:37.905269Z"
thomwolf's avatar
thomwolf committed
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(128, 768)\n",
      "(128, 768)\n"
     ]
thomwolf's avatar
thomwolf committed
1204
1205
1206
    }
   ],
   "source": [
thomwolf's avatar
thomwolf committed
1207
1208
    "print(tensorflow_outputs[0].shape)\n",
    "print(tensorflow_outputs[1].shape)"
thomwolf's avatar
thomwolf committed
1209
1210
   ]
  },
VictorSanh's avatar
VictorSanh committed
1211
1212
1213
1214
1215
1216
1217
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3/ Comparing the outputs of both models on every layer (root-mean-square difference, reported below as the standard deviation)"
   ]
  },
thomwolf's avatar
thomwolf committed
1218
1219
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
1220
   "execution_count": 32,
thomwolf's avatar
thomwolf committed
1221
1222
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
1223
1224
     "end_time": "2018-11-15T15:21:39.437137Z",
     "start_time": "2018-11-15T15:21:39.406150Z"
thomwolf's avatar
thomwolf committed
1225
1226
1227
1228
1229
1230
1231
1232
1233
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
thomwolf's avatar
thomwolf committed
1234
   "execution_count": 33,
thomwolf's avatar
thomwolf committed
1235
1236
   "metadata": {
    "ExecuteTime": {
thomwolf's avatar
thomwolf committed
1237
1238
     "end_time": "2018-11-15T15:21:40.181870Z",
     "start_time": "2018-11-15T15:21:40.137023Z"
thomwolf's avatar
thomwolf committed
1239
1240
1241
    }
   },
   "outputs": [
thomwolf's avatar
thomwolf committed
1242
1243
1244
1245
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
      "shape tensorflow layer, shape pytorch layer, standard deviation\n",
      "((128, 768), (128, 768), 1.5258875e-07)\n",
      "((128, 768), (128, 768), 2.342731e-07)\n",
      "((128, 768), (128, 768), 2.801949e-07)\n",
      "((128, 768), (128, 768), 3.5904986e-07)\n",
      "((128, 768), (128, 768), 4.2842768e-07)\n",
      "((128, 768), (128, 768), 5.127951e-07)\n",
      "((128, 768), (128, 768), 6.14668e-07)\n",
      "((128, 768), (128, 768), 7.063922e-07)\n",
      "((128, 768), (128, 768), 7.906173e-07)\n",
      "((128, 768), (128, 768), 8.475192e-07)\n",
      "((128, 768), (128, 768), 8.975489e-07)\n",
      "((128, 768), (128, 768), 4.1671223e-07)\n"
thomwolf's avatar
thomwolf committed
1259
     ]
thomwolf's avatar
thomwolf committed
1260
1261
1262
    }
   ],
   "source": [
1263
1264
1265
1266
    "print('shape tensorflow layer, shape pytorch layer, standard deviation')\n",
    "print('\\n'.join(list(str((np.array(tensorflow_outputs[i]).shape,\n",
    "                          np.array(pytorch_outputs[i]).shape, \n",
    "                          np.sqrt(np.mean((np.array(tensorflow_outputs[i]) - np.array(pytorch_outputs[i]))**2.0)))) for i in range(12))))"
1267
   ]
1268
1269
1270
1271
1272
1273
1274
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
1275
1276
1277
1278
1279
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
1280
   "display_name": "Python [default]",
1281
   "language": "python",
VictorSanh's avatar
VictorSanh committed
1282
   "name": "python3"
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
1294
   "version": "3.6.7"
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "48px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}