Commit 907d3569 authored by thomwolf's avatar thomwolf
Browse files

cleaning up SQuAD notebook - more explanation - fixing error

parent 1a5bbd83
......@@ -22,8 +22,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:01.181879Z",
"start_time": "2018-11-05T13:58:01.167184Z"
"end_time": "2018-11-06T10:11:33.636911Z",
"start_time": "2018-11-06T10:11:33.623091Z"
}
},
"outputs": [],
......@@ -44,8 +44,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:01.196873Z",
"start_time": "2018-11-05T13:58:01.184052Z"
"end_time": "2018-11-06T10:11:33.651792Z",
"start_time": "2018-11-06T10:11:33.638984Z"
}
},
"outputs": [],
......@@ -62,6 +62,7 @@
"outside_pos = max_seq_length + 10\n",
"doc_stride = 128\n",
"max_query_length = 64\n",
"max_answer_length = 30\n",
"output_dir = \"/tmp/squad_base/\"\n",
"learning_rate = 3e-5"
]
......@@ -71,8 +72,8 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:02.284694Z",
"start_time": "2018-11-05T13:58:01.198324Z"
"end_time": "2018-11-06T10:11:35.165788Z",
"start_time": "2018-11-06T10:11:33.653401Z"
}
},
"outputs": [],
......@@ -98,8 +99,8 @@
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:04.517331Z",
"start_time": "2018-11-05T13:58:02.287351Z"
"end_time": "2018-11-06T10:11:37.494391Z",
"start_time": "2018-11-06T10:11:35.168615Z"
}
},
"outputs": [
......@@ -390,14 +391,14 @@
"INFO:tensorflow:*** Example ***\n",
"INFO:tensorflow:unique_id: 1000000016\n",
"INFO:tensorflow:example_index: 16\n",
"INFO:tensorflow:doc_span_index: 0\n"
"INFO:tensorflow:doc_span_index: 0\n",
"INFO:tensorflow:tokens: [CLS] in what year was the college of engineering at notre dame formed ? [SEP] the college of engineering was established in 1920 , however , early courses in civil and mechanical engineering were a part of the college of science since the 1870s . today the college , housed in the fitzpatrick , cu ##shing , and st ##ins ##on - re ##mic ##k halls of engineering , includes five departments of study – aerospace and mechanical engineering , chemical and bio ##mo ##le ##cular engineering , civil engineering and geological sciences , computer science and engineering , and electrical engineering – with eight b . s . degrees offered . additionally , the college offers five - year dual degree programs with the colleges of arts and letters and of business awarding additional b . a . and master of business administration ( mba ) degrees , respectively . [SEP]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:tokens: [CLS] in what year was the college of engineering at notre dame formed ? [SEP] the college of engineering was established in 1920 , however , early courses in civil and mechanical engineering were a part of the college of science since the 1870s . today the college , housed in the fitzpatrick , cu ##shing , and st ##ins ##on - re ##mic ##k halls of engineering , includes five departments of study – aerospace and mechanical engineering , chemical and bio ##mo ##le ##cular engineering , civil engineering and geological sciences , computer science and engineering , and electrical engineering – with eight b . s . degrees offered . additionally , the college offers five - year dual degree programs with the colleges of arts and letters and of business awarding additional b . a . and master of business administration ( mba ) degrees , respectively . [SEP]\n",
"INFO:tensorflow:token_to_orig_map: 15:0 16:1 17:2 18:3 19:4 20:5 21:6 22:7 23:7 24:8 25:8 26:9 27:10 28:11 29:12 30:13 31:14 32:15 33:16 34:17 35:18 36:19 37:20 38:21 39:22 40:23 41:24 42:25 43:26 44:26 45:27 46:28 47:29 48:29 49:30 50:31 51:32 52:33 53:33 54:34 55:34 56:34 57:35 58:36 59:36 60:36 61:36 62:36 63:36 64:36 65:37 66:38 67:39 68:39 69:40 70:41 71:42 72:43 73:44 74:45 75:46 76:47 77:48 78:49 79:49 80:50 81:51 82:52 83:52 84:52 85:52 86:53 87:53 88:54 89:55 90:56 91:57 92:58 93:58 94:59 95:60 96:61 97:62 98:62 99:63 100:64 101:65 102:66 103:67 104:68 105:69 106:69 107:69 108:69 109:70 110:71 111:71 112:72 113:72 114:73 115:74 116:75 117:76 118:76 119:76 120:77 121:78 122:79 123:80 124:81 125:82 126:83 127:84 128:85 129:86 130:87 131:88 132:89 133:90 134:91 135:92 136:92 137:92 138:92 139:93 140:94 141:95 142:96 143:97 144:98 145:98 146:98 147:99 148:99 149:100 150:100\n",
"INFO:tensorflow:token_is_max_context: 15:True 16:True 17:True 18:True 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:True 130:True 131:True 132:True 133:True 134:True 135:True 136:True 137:True 138:True 139:True 140:True 141:True 142:True 143:True 144:True 145:True 146:True 147:True 148:True 149:True 150:True\n",
"INFO:tensorflow:input_ids: 101 1999 2054 2095 2001 1996 2267 1997 3330 2012 10289 8214 2719 1029 102 1996 2267 1997 3330 2001 2511 1999 4444 1010 2174 1010 2220 5352 1999 2942 1998 6228 3330 2020 1037 2112 1997 1996 2267 1997 2671 2144 1996 14896 1012 2651 1996 2267 1010 7431 1999 1996 26249 1010 12731 12227 1010 1998 2358 7076 2239 1011 2128 7712 2243 9873 1997 3330 1010 2950 2274 7640 1997 2817 1516 13395 1998 6228 3330 1010 5072 1998 16012 5302 2571 15431 3330 1010 2942 3330 1998 9843 4163 1010 3274 2671 1998 3330 1010 1998 5992 3330 1516 2007 2809 1038 1012 1055 1012 5445 3253 1012 5678 1010 1996 2267 4107 2274 1011 2095 7037 3014 3454 2007 1996 6667 1997 2840 1998 4144 1998 1997 2449 21467 3176 1038 1012 1037 1012 1998 3040 1997 2449 3447 1006 15038 1007 5445 1010 4414 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
......@@ -412,14 +413,14 @@
"INFO:tensorflow:doc_span_index: 0\n",
"INFO:tensorflow:tokens: [CLS] before the creation of the college of engineering similar studies were carried out at which notre dame college ? [SEP] the college of engineering was established in 1920 , however , early courses in civil and mechanical engineering were a part of the college of science since the 1870s . today the college , housed in the fitzpatrick , cu ##shing , and st ##ins ##on - re ##mic ##k halls of engineering , includes five departments of study – aerospace and mechanical engineering , chemical and bio ##mo ##le ##cular engineering , civil engineering and geological sciences , computer science and engineering , and electrical engineering – with eight b . s . degrees offered . additionally , the college offers five - year dual degree programs with the colleges of arts and letters and of business awarding additional b . a . and master of business administration ( mba ) degrees , respectively . [SEP]\n",
"INFO:tensorflow:token_to_orig_map: 21:0 22:1 23:2 24:3 25:4 26:5 27:6 28:7 29:7 30:8 31:8 32:9 33:10 34:11 35:12 36:13 37:14 38:15 39:16 40:17 41:18 42:19 43:20 44:21 45:22 46:23 47:24 48:25 49:26 50:26 51:27 52:28 53:29 54:29 55:30 56:31 57:32 58:33 59:33 60:34 61:34 62:34 63:35 64:36 65:36 66:36 67:36 68:36 69:36 70:36 71:37 72:38 73:39 74:39 75:40 76:41 77:42 78:43 79:44 80:45 81:46 82:47 83:48 84:49 85:49 86:50 87:51 88:52 89:52 90:52 91:52 92:53 93:53 94:54 95:55 96:56 97:57 98:58 99:58 100:59 101:60 102:61 103:62 104:62 105:63 106:64 107:65 108:66 109:67 110:68 111:69 112:69 113:69 114:69 115:70 116:71 117:71 118:72 119:72 120:73 121:74 122:75 123:76 124:76 125:76 126:77 127:78 128:79 129:80 130:81 131:82 132:83 133:84 134:85 135:86 136:87 137:88 138:89 139:90 140:91 141:92 142:92 143:92 144:92 145:93 146:94 147:95 148:96 149:97 150:98 151:98 152:98 153:99 154:99 155:100 156:100\n",
"INFO:tensorflow:token_is_max_context: 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:True 130:True 131:True 132:True 133:True 134:True 135:True 136:True 137:True 138:True 139:True 140:True 141:True 142:True 143:True 144:True 145:True 146:True 147:True 148:True 149:True 150:True 151:True 152:True 153:True 154:True 155:True 156:True\n"
"INFO:tensorflow:token_is_max_context: 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:True 130:True 131:True 132:True 133:True 134:True 135:True 136:True 137:True 138:True 139:True 140:True 141:True 142:True 143:True 144:True 145:True 146:True 147:True 148:True 149:True 150:True 151:True 152:True 153:True 154:True 155:True 156:True\n",
"INFO:tensorflow:input_ids: 101 2077 1996 4325 1997 1996 2267 1997 3330 2714 2913 2020 3344 2041 2012 2029 10289 8214 2267 1029 102 1996 2267 1997 3330 2001 2511 1999 4444 1010 2174 1010 2220 5352 1999 2942 1998 6228 3330 2020 1037 2112 1997 1996 2267 1997 2671 2144 1996 14896 1012 2651 1996 2267 1010 7431 1999 1996 26249 1010 12731 12227 1010 1998 2358 7076 2239 1011 2128 7712 2243 9873 1997 3330 1010 2950 2274 7640 1997 2817 1516 13395 1998 6228 3330 1010 5072 1998 16012 5302 2571 15431 3330 1010 2942 3330 1998 9843 4163 1010 3274 2671 1998 3330 1010 1998 5992 3330 1516 2007 2809 1038 1012 1055 1012 5445 3253 1012 5678 1010 1996 2267 4107 2274 1011 2095 7037 3014 3454 2007 1996 6667 1997 2840 1998 4144 1998 1997 2449 21467 3176 1038 1012 1037 1012 1998 3040 1997 2449 3447 1006 15038 1007 5445 1010 4414 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:input_ids: 101 2077 1996 4325 1997 1996 2267 1997 3330 2714 2913 2020 3344 2041 2012 2029 10289 8214 2267 1029 102 1996 2267 1997 3330 2001 2511 1999 4444 1010 2174 1010 2220 5352 1999 2942 1998 6228 3330 2020 1037 2112 1997 1996 2267 1997 2671 2144 1996 14896 1012 2651 1996 2267 1010 7431 1999 1996 26249 1010 12731 12227 1010 1998 2358 7076 2239 1011 2128 7712 2243 9873 1997 3330 1010 2950 2274 7640 1997 2817 1516 13395 1998 6228 3330 1010 5072 1998 16012 5302 2571 15431 3330 1010 2942 3330 1998 9843 4163 1010 3274 2671 1998 3330 1010 1998 5992 3330 1516 2007 2809 1038 1012 1055 1012 5445 3253 1012 5678 1010 1996 2267 4107 2274 1011 2095 7037 3014 3454 2007 1996 6667 1997 2840 1998 4144 1998 1997 2449 21467 3176 1038 1012 1037 1012 1998 3040 1997 2449 3447 1006 15038 1007 5445 1010 4414 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:start_position: 43\n",
......@@ -434,13 +435,7 @@
"INFO:tensorflow:token_is_max_context: 19:True 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:True 130:True 131:True 132:True 133:True 134:True 135:True 136:True 137:True 138:True 139:True 140:True 141:True 142:True 143:True 144:True 145:True 146:True 147:True 148:True 149:True 150:True 151:True 152:True 153:True 154:True\n",
"INFO:tensorflow:input_ids: 101 2129 2116 7640 2024 2306 1996 2358 7076 2239 1011 2128 7712 2243 2534 1997 3330 1029 102 1996 2267 1997 3330 2001 2511 1999 4444 1010 2174 1010 2220 5352 1999 2942 1998 6228 3330 2020 1037 2112 1997 1996 2267 1997 2671 2144 1996 14896 1012 2651 1996 2267 1010 7431 1999 1996 26249 1010 12731 12227 1010 1998 2358 7076 2239 1011 2128 7712 2243 9873 1997 3330 1010 2950 2274 7640 1997 2817 1516 13395 1998 6228 3330 1010 5072 1998 16012 5302 2571 15431 3330 1010 2942 3330 1998 9843 4163 1010 3274 2671 1998 3330 1010 1998 5992 3330 1516 2007 2809 1038 1012 1055 1012 5445 3253 1012 5678 1010 1996 2267 4107 2274 1011 2095 7037 3014 3454 2007 1996 6667 1997 2840 1998 4144 1998 1997 2449 21467 3176 1038 1012 1037 1012 1998 3040 1997 2449 3447 1006 15038 1007 5445 1010 4414 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:start_position: 74\n",
"INFO:tensorflow:end_position: 74\n",
"INFO:tensorflow:answer: five\n",
......@@ -448,7 +443,13 @@
"INFO:tensorflow:unique_id: 1000000019\n",
"INFO:tensorflow:example_index: 19\n",
"INFO:tensorflow:doc_span_index: 0\n",
"INFO:tensorflow:tokens: [CLS] the college of science began to offer civil engineering courses beginning at what time at notre dame ? [SEP] the college of engineering was established in 1920 , however , early courses in civil and mechanical engineering were a part of the college of science since the 1870s . today the college , housed in the fitzpatrick , cu ##shing , and st ##ins ##on - re ##mic ##k halls of engineering , includes five departments of study – aerospace and mechanical engineering , chemical and bio ##mo ##le ##cular engineering , civil engineering and geological sciences , computer science and engineering , and electrical engineering – with eight b . s . degrees offered . additionally , the college offers five - year dual degree programs with the colleges of arts and letters and of business awarding additional b . a . and master of business administration ( mba ) degrees , respectively . [SEP]\n",
"INFO:tensorflow:tokens: [CLS] the college of science began to offer civil engineering courses beginning at what time at notre dame ? [SEP] the college of engineering was established in 1920 , however , early courses in civil and mechanical engineering were a part of the college of science since the 1870s . today the college , housed in the fitzpatrick , cu ##shing , and st ##ins ##on - re ##mic ##k halls of engineering , includes five departments of study – aerospace and mechanical engineering , chemical and bio ##mo ##le ##cular engineering , civil engineering and geological sciences , computer science and engineering , and electrical engineering – with eight b . s . degrees offered . additionally , the college offers five - year dual degree programs with the colleges of arts and letters and of business awarding additional b . a . and master of business administration ( mba ) degrees , respectively . [SEP]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:token_to_orig_map: 20:0 21:1 22:2 23:3 24:4 25:5 26:6 27:7 28:7 29:8 30:8 31:9 32:10 33:11 34:12 35:13 36:14 37:15 38:16 39:17 40:18 41:19 42:20 43:21 44:22 45:23 46:24 47:25 48:26 49:26 50:27 51:28 52:29 53:29 54:30 55:31 56:32 57:33 58:33 59:34 60:34 61:34 62:35 63:36 64:36 65:36 66:36 67:36 68:36 69:36 70:37 71:38 72:39 73:39 74:40 75:41 76:42 77:43 78:44 79:45 80:46 81:47 82:48 83:49 84:49 85:50 86:51 87:52 88:52 89:52 90:52 91:53 92:53 93:54 94:55 95:56 96:57 97:58 98:58 99:59 100:60 101:61 102:62 103:62 104:63 105:64 106:65 107:66 108:67 109:68 110:69 111:69 112:69 113:69 114:70 115:71 116:71 117:72 118:72 119:73 120:74 121:75 122:76 123:76 124:76 125:77 126:78 127:79 128:80 129:81 130:82 131:83 132:84 133:85 134:86 135:87 136:88 137:89 138:90 139:91 140:92 141:92 142:92 143:92 144:93 145:94 146:95 147:96 148:97 149:98 150:98 151:98 152:99 153:99 154:100 155:100\n",
"INFO:tensorflow:token_is_max_context: 20:True 21:True 22:True 23:True 24:True 25:True 26:True 27:True 28:True 29:True 30:True 31:True 32:True 33:True 34:True 35:True 36:True 37:True 38:True 39:True 40:True 41:True 42:True 43:True 44:True 45:True 46:True 47:True 48:True 49:True 50:True 51:True 52:True 53:True 54:True 55:True 56:True 57:True 58:True 59:True 60:True 61:True 62:True 63:True 64:True 65:True 66:True 67:True 68:True 69:True 70:True 71:True 72:True 73:True 74:True 75:True 76:True 77:True 78:True 79:True 80:True 81:True 82:True 83:True 84:True 85:True 86:True 87:True 88:True 89:True 90:True 91:True 92:True 93:True 94:True 95:True 96:True 97:True 98:True 99:True 100:True 101:True 102:True 103:True 104:True 105:True 106:True 107:True 108:True 109:True 110:True 111:True 112:True 113:True 114:True 115:True 116:True 117:True 118:True 119:True 120:True 121:True 122:True 123:True 124:True 125:True 126:True 127:True 128:True 129:True 130:True 131:True 132:True 133:True 134:True 135:True 136:True 137:True 138:True 139:True 140:True 141:True 142:True 143:True 144:True 145:True 146:True 147:True 148:True 149:True 150:True 151:True 152:True 153:True 154:True 155:True\n",
"INFO:tensorflow:input_ids: 101 1996 2267 1997 2671 2211 2000 3749 2942 3330 5352 2927 2012 2054 2051 2012 10289 8214 1029 102 1996 2267 1997 3330 2001 2511 1999 4444 1010 2174 1010 2220 5352 1999 2942 1998 6228 3330 2020 1037 2112 1997 1996 2267 1997 2671 2144 1996 14896 1012 2651 1996 2267 1010 7431 1999 1996 26249 1010 12731 12227 1010 1998 2358 7076 2239 1011 2128 7712 2243 9873 1997 3330 1010 2950 2274 7640 1997 2817 1516 13395 1998 6228 3330 1010 5072 1998 16012 5302 2571 15431 3330 1010 2942 3330 1998 9843 4163 1010 3274 2671 1998 3330 1010 1998 5992 3330 1516 2007 2809 1038 1012 1055 1012 5445 3253 1012 5678 1010 1996 2267 4107 2274 1011 2095 7037 3014 3454 2007 1996 6667 1997 2840 1998 4144 1998 1997 2449 21467 3176 1038 1012 1037 1012 1998 3040 1997 2449 3447 1006 15038 1007 5445 1010 4414 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
......@@ -487,8 +488,8 @@
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:04.566557Z",
"start_time": "2018-11-05T13:58:04.522405Z"
"end_time": "2018-11-06T10:11:37.525632Z",
"start_time": "2018-11-06T10:11:37.498695Z"
}
},
"outputs": [],
......@@ -503,8 +504,8 @@
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:04.602691Z",
"start_time": "2018-11-05T13:58:04.568978Z"
"end_time": "2018-11-06T10:11:37.558325Z",
"start_time": "2018-11-06T10:11:37.527972Z"
}
},
"outputs": [],
......@@ -578,8 +579,8 @@
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:04.648232Z",
"start_time": "2018-11-05T13:58:04.604691Z"
"end_time": "2018-11-06T10:11:37.601666Z",
"start_time": "2018-11-06T10:11:37.560082Z"
}
},
"outputs": [],
......@@ -709,8 +710,8 @@
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:08.389636Z",
"start_time": "2018-11-05T13:58:04.649873Z"
"end_time": "2018-11-06T10:11:41.104542Z",
"start_time": "2018-11-06T10:11:37.603474Z"
}
},
"outputs": [
......@@ -718,14 +719,14 @@
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x12d72fe18>) includes params argument, but params are not passed to Estimator.\n",
"WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x120df3f28>) includes params argument, but params are not passed to Estimator.\n",
"INFO:tensorflow:Using config: {'_model_dir': '/tmp/squad_base/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x12c676a20>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=1000, num_shards=8, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x11fd09630>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=1000, num_shards=8, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
"INFO:tensorflow:_TPUContext: eval_on_tpu True\n",
"WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n"
]
......@@ -770,8 +771,8 @@
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:14.923839Z",
"start_time": "2018-11-05T13:58:08.391524Z"
"end_time": "2018-11-06T10:11:47.857601Z",
"start_time": "2018-11-06T10:11:41.106219Z"
}
},
"outputs": [
......@@ -1013,6 +1014,7 @@
],
"source": [
"tensorflow_all_out = []\n",
"tensorflow_all_results = []\n",
"for result in estimator.predict(predict_input_fn, yield_single_examples=True):\n",
" unique_id = int(result[\"unique_ids\"])\n",
" eval_feature = eval_unique_id_to_feature[unique_id]\n",
......@@ -1031,6 +1033,10 @@
" output_json[\"start_loss\"] = [round(float(x), 6) for x in start_loss.flat]\n",
" output_json[\"end_loss\"] = [round(float(x), 6) for x in end_loss.flat]\n",
" tensorflow_all_out.append(output_json)\n",
" tensorflow_all_results.append(RawResult(\n",
" unique_id=unique_id,\n",
" start_logits=start_logits,\n",
" end_logits=end_logits))\n",
" break"
]
},
......@@ -1039,8 +1045,223 @@
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:14.958200Z",
"start_time": "2018-11-05T13:58:14.925742Z"
"end_time": "2018-11-06T10:11:47.912836Z",
"start_time": "2018-11-06T10:11:47.859679Z"
},
"code_folding": []
},
"outputs": [],
"source": [
"def _get_best_indexes(logits, n_best_size):\n",
" \"\"\"Get the n-best logits from a list.\"\"\"\n",
" index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)\n",
"\n",
" best_indexes = []\n",
" for i in range(len(index_and_score)):\n",
" if i >= n_best_size:\n",
" break\n",
" best_indexes.append(index_and_score[i][0])\n",
" return best_indexes\n",
"\n",
"def _compute_softmax(scores):\n",
" \"\"\"Compute softmax probability over raw logits.\"\"\"\n",
" if not scores:\n",
" return []\n",
"\n",
" max_score = None\n",
" for score in scores:\n",
" if max_score is None or score > max_score:\n",
" max_score = score\n",
"\n",
" exp_scores = []\n",
" total_sum = 0.0\n",
" for score in scores:\n",
" x = math.exp(score - max_score)\n",
" exp_scores.append(x)\n",
" total_sum += x\n",
"\n",
" probs = []\n",
" for score in exp_scores:\n",
" probs.append(score / total_sum)\n",
" return probs\n",
"\n",
"\n",
"def compute_predictions(all_examples, all_features, all_results, n_best_size,\n",
" max_answer_length, do_lower_case):\n",
" \"\"\"Compute final predictions.\"\"\"\n",
" example_index_to_features = collections.defaultdict(list)\n",
" for feature in all_features:\n",
" example_index_to_features[feature.example_index].append(feature)\n",
"\n",
" unique_id_to_result = {}\n",
" for result in all_results:\n",
" unique_id_to_result[result.unique_id] = result\n",
"\n",
" _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name\n",
" \"PrelimPrediction\",\n",
" [\"feature_index\", \"start_index\", \"end_index\", \"start_logit\", \"end_logit\"])\n",
"\n",
" all_predictions = collections.OrderedDict()\n",
" all_nbest_json = collections.OrderedDict()\n",
" for (example_index, example) in enumerate(all_examples):\n",
" features = example_index_to_features[example_index]\n",
"\n",
" prelim_predictions = []\n",
" for (feature_index, feature) in enumerate(features):\n",
" result = unique_id_to_result[feature.unique_id]\n",
"\n",
" start_indexes = _get_best_indexes(result.start_logits, n_best_size)\n",
" end_indexes = _get_best_indexes(result.end_logits, n_best_size)\n",
" for start_index in start_indexes:\n",
" for end_index in end_indexes:\n",
" # We could hypothetically create invalid predictions, e.g., predict\n",
" # that the start of the span is in the question. We throw out all\n",
" # invalid predictions.\n",
" if start_index >= len(feature.tokens):\n",
" continue\n",
" if end_index >= len(feature.tokens):\n",
" continue\n",
" if start_index not in feature.token_to_orig_map:\n",
" continue\n",
" if end_index not in feature.token_to_orig_map:\n",
" continue\n",
" if not feature.token_is_max_context.get(start_index, False):\n",
" continue\n",
" if end_index < start_index:\n",
" continue\n",
" length = end_index - start_index + 1\n",
" if length > max_answer_length:\n",
" continue\n",
" prelim_predictions.append(\n",
" _PrelimPrediction(\n",
" feature_index=feature_index,\n",
" start_index=start_index,\n",
" end_index=end_index,\n",
" start_logit=result.start_logits[start_index],\n",
" end_logit=result.end_logits[end_index]))\n",
"\n",
" prelim_predictions = sorted(\n",
" prelim_predictions,\n",
" key=lambda x: (x.start_logit + x.end_logit),\n",
" reverse=True)\n",
"\n",
" _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name\n",
" \"NbestPrediction\", [\"text\", \"start_logit\", \"end_logit\"])\n",
"\n",
" seen_predictions = {}\n",
" nbest = []\n",
" for pred in prelim_predictions:\n",
" if len(nbest) >= n_best_size:\n",
" break\n",
" feature = features[pred.feature_index]\n",
"\n",
" tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]\n",
" orig_doc_start = feature.token_to_orig_map[pred.start_index]\n",
" orig_doc_end = feature.token_to_orig_map[pred.end_index]\n",
" orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]\n",
" tok_text = \" \".join(tok_tokens)\n",
"\n",
" # De-tokenize WordPieces that have been split off.\n",
" tok_text = tok_text.replace(\" ##\", \"\")\n",
" tok_text = tok_text.replace(\"##\", \"\")\n",
"\n",
" # Clean whitespace\n",
" tok_text = tok_text.strip()\n",
" tok_text = \" \".join(tok_text.split())\n",
" orig_text = \" \".join(orig_tokens)\n",
"\n",
" final_text = get_final_text(tok_text, orig_text, do_lower_case)\n",
" if final_text in seen_predictions:\n",
" continue\n",
"\n",
" seen_predictions[final_text] = True\n",
" nbest.append(\n",
" _NbestPrediction(\n",
" text=final_text,\n",
" start_logit=pred.start_logit,\n",
" end_logit=pred.end_logit))\n",
"\n",
" # In very rare edge cases we could have no valid predictions. So we\n",
" # just create a nonce prediction in this case to avoid failure.\n",
" if not nbest:\n",
" nbest.append(\n",
" _NbestPrediction(text=\"empty\", start_logit=0.0, end_logit=0.0))\n",
"\n",
" assert len(nbest) >= 1\n",
"\n",
" total_scores = []\n",
" for entry in nbest:\n",
" total_scores.append(entry.start_logit + entry.end_logit)\n",
"\n",
" probs = _compute_softmax(total_scores)\n",
"\n",
" nbest_json = []\n",
" for (i, entry) in enumerate(nbest):\n",
" output = collections.OrderedDict()\n",
" output[\"text\"] = entry.text\n",
" output[\"probability\"] = probs[i]\n",
" output[\"start_logit\"] = entry.start_logit\n",
" output[\"end_logit\"] = entry.end_logit\n",
" nbest_json.append(output)\n",
"\n",
" assert len(nbest_json) >= 1\n",
"\n",
" all_predictions[example.qas_id] = nbest_json[0][\"text\"]\n",
" all_nbest_json[example.qas_id] = nbest_json\n",
"\n",
" return all_predictions, all_nbest_json"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-06T10:11:47.953205Z",
"start_time": "2018-11-06T10:11:47.914751Z"
}
},
"outputs": [],
"source": [
"all_predictions, all_nbest_json = compute_predictions(eval_examples[:1], eval_features[:1], tensorflow_all_results, 20, max_answer_length, True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-06T10:11:47.994647Z",
"start_time": "2018-11-06T10:11:47.955015Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('5733be284776f41900661182',\n",
" [OrderedDict([('text', 'empty'),\n",
" ('probability', 1.0),\n",
" ('start_logit', 0.0),\n",
" ('end_logit', 0.0)])])])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"all_nbest_json"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-06T10:11:48.028473Z",
"start_time": "2018-11-06T10:11:47.996311Z"
}
},
"outputs": [
......@@ -1068,11 +1289,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:14.995220Z",
"start_time": "2018-11-05T13:58:14.959782Z"
"end_time": "2018-11-06T10:11:48.060658Z",
"start_time": "2018-11-06T10:11:48.030289Z"
}
},
"outputs": [],
......@@ -1091,11 +1312,11 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:15.319031Z",
"start_time": "2018-11-05T13:58:14.997465Z"
"end_time": "2018-11-06T10:11:48.478814Z",
"start_time": "2018-11-06T10:11:48.062585Z"
}
},
"outputs": [],
......@@ -1106,11 +1327,11 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:15.352000Z",
"start_time": "2018-11-05T13:58:15.321118Z"
"end_time": "2018-11-06T10:11:48.512607Z",
"start_time": "2018-11-06T10:11:48.480729Z"
}
},
"outputs": [],
......@@ -1120,11 +1341,11 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:17.639270Z",
"start_time": "2018-11-05T13:58:15.353829Z"
"end_time": "2018-11-06T10:11:51.023405Z",
"start_time": "2018-11-06T10:11:48.514306Z"
},
"scrolled": true
},
......@@ -1135,7 +1356,7 @@
"tensor([0., 0.])"
]
},
"execution_count": 14,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
......@@ -1151,11 +1372,11 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:17.765324Z",
"start_time": "2018-11-05T13:58:17.641050Z"
"end_time": "2018-11-06T10:11:51.079364Z",
"start_time": "2018-11-06T10:11:51.028228Z"
},
"code_folding": []
},
......@@ -1169,9 +1390,9 @@
"all_end_positions = torch.tensor([[f.end_position] for f in eval_features], dtype=torch.long)\n",
"\n",
"eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,\n",
" all_start_positions, all_end_positions)\n",
" all_start_positions, all_end_positions, all_example_index)\n",
"eval_sampler = SequentialSampler(eval_data)\n",
"eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=2)\n",
"eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)\n",
"\n",
"model.eval()\n",
"None"
......@@ -1179,11 +1400,11 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:17.802459Z",
"start_time": "2018-11-05T13:58:17.767029Z"
"end_time": "2018-11-06T10:11:51.114686Z",
"start_time": "2018-11-06T10:11:51.081474Z"
}
},
"outputs": [
......@@ -1191,34 +1412,34 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[torch.Size([2, 384]), torch.Size([2, 384]), torch.Size([2, 384]), torch.Size([2, 1]), torch.Size([2, 1])]\n"
"[torch.Size([1, 384]), torch.Size([1, 384]), torch.Size([1, 384]), torch.Size([1, 1]), torch.Size([1, 1]), torch.Size([1])]\n"
]
},
{
"data": {
"text/plain": [
"torch.Size([2, 1])"
"torch.Size([1, 1])"
]
},
"execution_count": 16,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch = iter(eval_dataloader).next()\n",
"input_ids, input_mask, segment_ids, start_positions, end_positions = batch\n",
"input_ids, input_mask, segment_ids, start_positions, end_positions, example_index = batch\n",
"print([t.shape for t in batch])\n",
"start_positions.size()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 20,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:20.176106Z",
"start_time": "2018-11-05T13:58:17.803997Z"
"end_time": "2018-11-06T10:11:52.298367Z",
"start_time": "2018-11-06T10:11:51.116219Z"
}
},
"outputs": [
......@@ -1226,14 +1447,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Evaluating: 0%| | 0/135 [00:00<?, ?it/s]\n"
"Evaluating: 0%| | 0/270 [00:00<?, ?it/s]\n"
]
}
],
"source": [
"pytorch_all_out = []\n",
"for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
" input_ids, input_mask, segment_ids, start_positions, end_positions = batch\n",
" input_ids, input_mask, segment_ids, start_positions, end_positions, example_index = batch\n",
" input_ids = input_ids.to(device)\n",
" input_mask = input_mask.to(device)\n",
" segment_ids = segment_ids.to(device)\n",
......@@ -1242,26 +1463,25 @@
"\n",
" total_loss, (start_logits, end_logits) = model(input_ids, segment_ids, input_mask, start_positions, end_positions)\n",
" \n",
" unique_id = int(result[\"unique_ids\"])\n",
" eval_feature = eval_unique_id_to_feature[unique_id]\n",
" eval_feature = eval_features[example_index.item()]\n",
"\n",
" output_json = collections.OrderedDict()\n",
" output_json[\"linex_index\"] = unique_id\n",
" output_json[\"tokens\"] = [token for (i, token) in enumerate(eval_feature.tokens)]\n",
" output_json[\"total_loss\"] = result[\"total_loss\"]\n",
" output_json[\"start_logits\"] = result[\"start_logits\"]\n",
" output_json[\"end_logits\"] = result[\"end_logits\"]\n",
" output_json[\"total_loss\"] = total_loss.detach().cpu().numpy()\n",
" output_json[\"start_logits\"] = start_logits.detach().cpu().numpy()\n",
" output_json[\"end_logits\"] = end_logits.detach().cpu().numpy()\n",
" pytorch_all_out.append(output_json)\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 21,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:20.229288Z",
"start_time": "2018-11-05T13:58:20.180586Z"
"end_time": "2018-11-06T10:11:52.339553Z",
"start_time": "2018-11-06T10:11:52.300335Z"
}
},
"outputs": [
......@@ -1273,8 +1493,8 @@
"5\n",
"odict_keys(['linex_index', 'tokens', 'total_loss', 'start_logits', 'end_logits'])\n",
"number of tokens 176\n",
"number of start_logits 384\n",
"number of end_logits 384\n"
"number of start_logits 1\n",
"number of end_logits 1\n"
]
}
],
......@@ -1289,11 +1509,11 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 22,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:20.272079Z",
"start_time": "2018-11-05T13:58:20.231730Z"
"end_time": "2018-11-06T10:11:52.372827Z",
"start_time": "2018-11-06T10:11:52.341393Z"
}
},
"outputs": [],
......@@ -1310,11 +1530,11 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 23,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:20.311258Z",
"start_time": "2018-11-05T13:58:20.273919Z"
"end_time": "2018-11-06T10:11:52.402814Z",
"start_time": "2018-11-06T10:11:52.374329Z"
}
},
"outputs": [],
......@@ -1324,11 +1544,11 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 24,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:20.354231Z",
"start_time": "2018-11-05T13:58:20.313037Z"
"end_time": "2018-11-06T10:11:52.434743Z",
"start_time": "2018-11-06T10:11:52.404345Z"
}
},
"outputs": [
......@@ -1337,9 +1557,9 @@
"output_type": "stream",
"text": [
"shape tensorflow layer, shape pytorch layer, standard deviation\n",
"((384,), (384,), 2.813107017184736e-07)\n",
"((384,), (384,), 2.813107017184736e-07)\n",
"((1,), (1,), 2.0812988310581204e-07)\n"
"((384,), (1, 384), 5.244962470555037e-06)\n",
"((384,), (1, 384), 5.244962470555037e-06)\n",
"((1,), (), 4.560241698925438e-06)\n"
]
}
],
......@@ -1352,15 +1572,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 27,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:20.394369Z",
"start_time": "2018-11-05T13:58:20.356033Z"
"end_time": "2018-11-06T10:12:54.200059Z",
"start_time": "2018-11-06T10:12:54.167355Z"
}
},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total loss of the TF model 9.06024 - Total loss of the PT model 9.0602445602417\n"
]
}
],
"source": [
"print(\"Total loss of the TF model {} - Total loss of the PT model {}\".format(tensorflow_outputs[2][0], pytorch_outputs[2]))"
]
},
{
"cell_type": "code",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment