"src/array/vscode:/vscode.git/clone" did not exist on "85b8fe523644edf56ed182b932f156811a346b99"
Untitled.ipynb 20.9 KB
Newer Older
VoVAllen's avatar
VoVAllen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: DGL_LIBRARY_PATH=/data/jinjing/dgl/build\n"
     ]
    }
   ],
   "source": [
    "%env DGL_LIBRARY_PATH=/data/jinjing/dgl/build"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_points=10\n",
    "num_points2=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "theta_dist1=np.pi*(np.random.rand(num_points)*2)\n",
    "theta_dist2=np.pi*(np.random.rand(num_points2)*2)\n",
    "\n",
    "radius_1=np.random.rand(num_points)\n",
    "radius_2=np.random.rand(num_points2)*2+3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "points_1=np.array([np.cos(theta_dist1)*radius_1,np.sin(theta_dist1)*radius_1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "points_2=np.array([np.cos(theta_dist2)*radius_2,np.sin(theta_dist2)*radius_2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "point_3=np.array([1,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7fd47d377cf8>"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAD8CAYAAABjAo9vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEBlJREFUeJzt3X2IXNd9xvHnkSKTbGIwQZM4tbQ7hpqAUB27DCLBhVLFCYprYhII2EwCIYH9pwEHDCbuQktaBAVDmkICYWiUUDrEFBLh4JfaMlUwhvhl5UqqHMnBGK8skaI1JdhmIYmsX/+4M9Vqu28z98zcmXO/HxCz9+7l3B/S6uHsueee44gQACAfO6ouAACQFsEOAJkh2AEgMwQ7AGSGYAeAzBDsAJAZgh0AMkOwA0BmCHYAyMz7qrjp7t27o9lsVnFrAJhaJ06ceCsiGltdV0mwN5tNLS4uVnFrAJhatpe2cx1DMQCQmSQ9dttvSHpH0nuSLkdEK0W7AIDBpRyK+YuIeCthewCAITAUAwCZSRXsIelp2ydszydqEwAwhFRDMX8WERdtf0TSMdvnIuLZ1Rf0An9ekmZnZxPdFgCwVpIee0Rc7H1eknRU0oF1rulERCsiWo3GltMwgaS6XanZlHbsKD673aorAkandLDb/qDt6/tfS/qspDNl2wVS6Xal+XlpaUmKKD7n5wl35CtFj/2jkp6zfUrSi5Iej4h/T9AukMTCgrSycu25lZXiPJCj0mPsEfG6pE8kqAUYifPnBzsPTDumOyJ7Gz2r5xk+ckWwI3uHD0szM9eem5kpzgM5ItiRvXZb6nSkuTnJLj47neI8kKNKVncExq3dJshRH/TYASAzBDsAZIZgB4DMEOwARorlHMaPYAdqooqAZTmHahDsQA1UFbAs51ANgh2ogaoCluUcqkGwAzVQVcCynEM1CHagBqoKWJZzqAbBDtRAVQHLcg7VYEkBoAb6QbqwUAy/zM4WoT6OgGU5h/Ej2IGaIGDrg6EYAMgMwQ4AmUkW7LZ32v5P24+lahMAMLiUPfb7JZ1N2B4AYAhJgt32Hkl/KemfU7QHABheqh77dyU9KOnKRhfYnre9aHtxeXk50W0BAGuVDnbbd0u6FBEnNrsuIjoR0YqIVqPRKHtbAMAGUvTY75D0edtvSHpE0kHb/5qgXQDAEEoHe0Q8FBF7IqIp6V5J/xERXy5dGQBgKMxjB4DMJF1SICJ+IekXKdsEAAyGHjuA7NVt31UWAQOQtf62gP0dpPrbAkr5LopGjx1A1uq47yrBDiBrddx3lWAHkLU67rtKsAPIWh33XSXYUUt1myVRZ3Xcd5VZMaidOs6SqLu6bQtIjx21U8dZEqgXgh21U8dZEqgXgh21U8dZEqgXgh21U8dZEqgXgh21U8dZEqgXZsWgluo2SwL1Qo8dADJDsANAZgh2AMhM6WC3/X7bL9o+ZfsV299OURgAYDgpHp7+TtLBiHjX9i5Jz9l+MiKeT9A2AGBApXvsUXi3d7ir9yfKtgsAw2CBt0Rj7LZ32j4p6ZKkYxHxQop2AWAQ/QXelpakiKsLvNUt3JMEe0S8FxG3Sdoj6YDt/WuvsT1ve9H24vLycorbAsA1WOCtkHRWTET8VtJxSYfW+V4nIloR0Wo0GilvCwCSWOCtL8WsmIbtG3pff0DSZySdK9suAAyKBd4KKXrsH5N03PZpSS+pGGN/LEG7ADAQFngrpJgVczoibo+IWyNif0T8XYrCMBhmAgAs8NbHImAZYKs34CoWeGNJgSwwEwDAagR7BpgJAGA1gj0DzAQAsBrBngFmAgBYjWDPADMBAKzGrJhMMBMAQB89dgDIDMEOAJkh2AEgMwQ7AGSGYAeAERv3Wk7MigGAEapiLSd67AAwQlWs5USwA8AIVbGWE8EOACNUxVpOBDsAjFAVazkR7AAwQlWs5VR6VoztvZL+RdJHJYWk
TkT8U9l2ASAX417LKcV0x8uSHoiIl21fL+mE7WMR8asEbQMABpRiM+vfRMTLva/fkXRW0k1l2wUADCfpGLvtpqTbJb2Qsl0AwPYlC3bbH5L0U0nfjIi31/n+vO1F24vLy8upbgsAWCNJsNvepSLUuxHxs/WuiYhORLQiotVoNFLcFgCwjtLBbtuSfijpbER8p3xJAIAyUvTY75D0FUkHbZ/s/bkrQbsAgCGUnu4YEc9JcoJaAAAJ8OYpAGSGYAeAzBDsAJAZgh0AMkOwA0BmCHYAyAzBDgCZIdgBIDMEOwBkhmAHgMwQ7ACQGYIdADJDsANAZgh2AMgMwQ4AmSHYASAzBDsAZIZgB4DMJAl220dsX7J9JkV7AIDhpeqx/1jSoURtARijrqSmijBo9o4x3UpvZi1JEfGs7WaKtgCMT1fSvKSV3vFS71iS2pVUhBQYYwdqbEFXQ71vpXce02tswW573vai7cXl5eVx3RbAJs4PeB7TYWzBHhGdiGhFRKvRaIzrtgA2MTvgeUwHhmKAGjssaWbNuZneeUyvVNMdfyLpl5I+bvuC7a+naBfAaLUldSTNSXLvsyMenE67VLNi7kvRDoDxa4sgzw1DMcC4dLtSsynt2FF8dpkxjtFI0mMHsIVuV5qfl1Z6kwuXlopjSWrTX0Za9NiBcVhYuBrqfSsrxXkgMYIdGIfzG8wM3+g8UALBDozD7AYzwzc6D5RAsAPjcPiwNLNmxvjMTHEeSIxgB8ah3ZY6HWluTrKLz06HB6cYCWbFAOPSbhPkGAt67ACQGYIdADJDsANAZgh2AMgMwQ4AmSHYASAzBDswCilWcmQ1SAyJYAdS66/kuLQkRVxdybHb3X5Yb9YGsAWCHdjMML3mjVZyvP/+7Yc1q0GiBEfE2G/aarVicXFx7PcFBrJ2DXWpWN9lq6UAduwognu75uakN97YXhu2dOXK9ttGVmyfiIjWVtel2vP0kO1Xbb9m+1sp2gQqN2yv+cMfHuw+6y3dy2qQKKF0sNveKen7kj4naZ+k+2zvK9suULlh1lDvdqW33x7sPuuFNatBooQUPfYDkl6LiNcj4veSHpF0T4J2gWoN02teWJD+8Ift32OjsGY1SJSQIthvkvTmquMLvXPXsD1ve9H24vLycoLbAiM2TK950B2RNgvrdrsYe79ypfgk1LFNY5sVExGdiGhFRKvRaIzrtsDwhuk1DzIGPjdHWGMkUgT7RUl7Vx3v6Z0Dpt+gveb1evm7dknXXXftOcbLMUIpgv0lSbfYvtn2dZLulfTzBO0C02e9Xv6PfiQdOcJ4OcYmyTx223dJ+q6knZKORMSmXRHmsWOqdbvFQ9Lz54uhl8OHCWmMxXbnsSfZGi8inpD0RIq2gIm29qWl/tujEuGOicGSAsAgeNUfU4BgBwYxzEtLwJgR7MAgeNUfU4BgBwbBq/6YAgQ7MAhe9ccUSDIrBqiVdpsgx0Sjxw4AmSHYgdWmeZ/Raa4dSU1VsPNzi5Ga5n1Gp7l2JDc1W+MNu0sZsG3NZhGIa623dd2kmebasW3bXVJgaoKdn1uM3DTvMzrNtWPbxrrn6Tjwwh9GbppfPprm2pHc1AQ7P7cYuWl++Wiaa0dyUxPs/Nxi5Kb55aNprh3JTc0Yu8Qy2ADqbazrsY8LL/wBwNamZigGALA9pYLd9pdsv2L7iu0tfz0AAIxe2R77GUlflPRsgloAAAmUGmOPiLOSZDtNNQCA0hhjB4DMbNljt/2MpBvX+dZCRDy63RvZnpc0L0mzvFUEACOzZbBHxJ0pbhQRHUkdqZjHnqJNAMD/x1AMAGSm7HTHL9i+IOlTkh63/VSasgAAwyo7K+aopKOJagEAJMBQDABkhmAHgMwQ7ACQGYIdADJDsANAZgh2AMgMwQ4AmSHYASAzBDsAZIZgB4DMEOwAkBmCHQAyQ7ADQGYIdgDIDMEOAJkh2AEgMwQ7AGSGYAeAzJTd8/Rh2+dsn7Z91PYNqQoDAAynbI/9
mKT9EXGrpF9Leqh8SQCAMkoFe0Q8HRGXe4fPS9pTviQAQBkpx9i/JunJhO0BAIbwvq0usP2MpBvX+dZCRDzau2ZB0mVJ3U3amZc0L0mzs7NDFQsA2NqWwR4Rd272fdtflXS3pE9HRGzSTkdSR5JardaG1wEAytky2Ddj+5CkByX9eUSspCkJAFBG2TH270m6XtIx2ydt/yBBTQCAEkr12CPij1MVAgBIgzdPASAzBDsAZIZgB4DMEOwAkBmCHQAyQ7AjiW5XajalHTuKz+6G7yADGLVS0x0BqQjx+XlppfeK2tJScSxJ7XZ1dQF1RY8dpS0sXA31vpWV4jyA8SPYUdr584OdBzBaBDtK22ixThbxBKpBsKO0w4elmZlrz83MFOcBjB/BjtLabanTkebmJLv47HR4cApUhVkxSKLdJsiBSUGPHQAyQ7ADQGYIdgDIDMEOAJkh2AEgMwQ7AGTGETH+m9rLkpbGfuOrdkt6q8L7b2aSa5Mmu75Jrk2ivjImuTZpfPXNRURjq4sqCfaq2V6MiFbVdaxnkmuTJru+Sa5Nor4yJrk2afLqYygGADJDsANAZuoa7J2qC9jEJNcmTXZ9k1ybRH1lTHJt0oTVV8sxdgDIWV177ACQrdoHu+0HbIft3VXX0mf7722ftn3S9tO2/6jqmlaz/bDtc70aj9q+oeqa+mx/yfYrtq/YnphZCrYP2X7V9mu2v1V1PavZPmL7ku0zVdeylu29to/b/lXv3/X+qmvqs/1+2y/aPtWr7dtV19RX62C3vVfSZyVN2iZuD0fErRFxm6THJP1N1QWtcUzS/oi4VdKvJT1UcT2rnZH0RUnPVl1In+2dkr4v6XOS9km6z/a+aqu6xo8lHaq6iA1clvRAROyT9ElJfzVBf3e/k3QwIj4h6TZJh2x/suKaJNU82CX9o6QHJU3Ug4aIeHvV4Qc1efU9HRGXe4fPS9pTZT2rRcTZiHi16jrWOCDptYh4PSJ+L+kRSfdUXNP/iYhnJf1P1XWsJyJ+ExEv975+R9JZSTdVW1UhCu/2Dnf1/kzE/9XaBrvteyRdjIhTVdeyHtuHbb8pqa3J67Gv9jVJT1ZdxIS7SdKbq44vaELCaZrYbkq6XdIL1VZyle2dtk9KuiTpWERMRG1Z76Bk+xlJN67zrQVJf61iGKYSm9UWEY9GxIKkBdsPSfqGpL+dpPp61yyo+FW5O2m1IS+2PyTpp5K+ueY32kpFxHuSbus9Zzpqe39EVP6sIutgj4g71ztv+08k3SzplG2pGEp42faBiPjvKmtbR1fSExpzsG9Vn+2vSrpb0qdjzHNmB/i7mxQXJe1ddbyndw7bYHuXilDvRsTPqq5nPRHxW9vHVTyrqDzYazkUExH/FREfiYhmRDRV/Gr8p+MK9a3YvmXV4T2SzlVVy3psH1LxbOLzEbFSdT1T4CVJt9i+2fZ1ku6V9POKa5oKLnpeP5R0NiK+U3U9q9lu9GeE2f6ApM9oQv6v1jLYp8A/2D5j+7SK4aKJmeLV8z1J10s61puS+YOqC+qz/QXbFyR9StLjtp+quqbeg+ZvSHpKxcO/f4uIV6qt6irbP5H0S0kft33B9terrmmVOyR9RdLB3s/aSdt3VV1Uz8ckHe/9P31JxRj7YxXXJIk3TwEgO/TYASAzBDsAZIZgB4DMEOwAkBmCHQAyQ7ADQGYIdgDIDMEOAJn5XxKAS7tplJrgAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(*points_1,color='red')\n",
    "plt.scatter(*points_2,color='blue')\n",
    "plt.scatter(*point_3,color='aqua')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as th\n",
    "import dgl\n",
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "g=dgl.DGLGraph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.add_nodes(num_points+num_points2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(num_points):\n",
    "    for j in range(num_points2):\n",
    "        g.add_edge(i,j+num_points)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes1=list(range(num_points))\n",
    "nodes2=list(range(num_points,num_points+num_points2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.set_n_repr({'h':th.from_numpy(np.concatenate([points_1,points_2],axis=1).T)})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.update_edge(edge_func=lambda u,v,edge:{'u_hat':u['h']})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-initialise one b_ij value per edge. The edge-building loop adds one edge\n",
    "# for every (i, j) pair, so the count is num_points*num_points2 rather than a\n",
    "# hard-coded 100.\n",
    "g.set_e_repr({'b_ij':th.zeros(num_points*num_points2)})\n",
    "# 'u_hat' is not set here: the earlier g.update_edge(...) cell fills it from the\n",
    "# source node's 'h'. The unfinished `g.set_e_repr({'u_hat':})` line was a\n",
    "# SyntaxError that kept this whole cell from running."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_edge(self, u, v, edge):\n",
    "    \"\"\"Return the edge's 'b_ij' increased by a term built from v['h'] and edge['u_hat'].\n",
    "\n",
    "    The term is the elementwise product v['h'] * edge['u_hat'], averaged over\n",
    "    dim 1 and then summed over dim 1 of that result. `self` and `u` are\n",
    "    accepted but never read.\n",
    "    NOTE(review): the leading `self` suggests this was lifted from a class\n",
    "    method -- confirm how the graph API is expected to call it.\n",
    "    \"\"\"\n",
    "    increment = v['h'] * edge['u_hat']\n",
    "    increment = increment.mean(dim=1)\n",
    "    increment = increment.sum(dim=1)\n",
    "    updated_b_ij = edge['b_ij'] + increment\n",
    "    return {'b_ij': updated_b_ij}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "def v2_reduce(node, msg):\n",
    "    \"\"\"Reduce incoming messages into a new 'h' as a softmax-weighted sum of 'u_hat'.\n",
    "\n",
    "    msg['b_ij'] is soft-maxed along dim 0, broadcast against msg['u_hat'] and\n",
    "    the weighted 'u_hat' is summed over dim 1. `node` is accepted but not read.\n",
    "\n",
    "    The weights get exactly as many trailing singleton axes as 'u_hat' has\n",
    "    extra dimensions. The previous fixed `.unsqueeze(2).unsqueeze(3)` only\n",
    "    lined up with a 4-D 'u_hat' and raised a size-mismatch RuntimeError for the\n",
    "    3-D 'u_hat' this notebook produces.\n",
    "    \"\"\"\n",
    "    b_ij_c, u_hat = msg['b_ij'], msg['u_hat']\n",
    "    c_i = F.softmax(b_ij_c, dim=0)\n",
    "    # Pad on the right so c_i's leading axes stay aligned with u_hat's.\n",
    "    while c_i.dim() < u_hat.dim():\n",
    "        c_i = c_i.unsqueeze(-1)\n",
    "    s_j = (c_i * u_hat).sum(dim=1)\n",
    "    return {'h': s_j}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [],
   "source": [
    "def capsule_msg(src, edge):\n",
    "    \"\"\"Build an edge message: the edge's 'b_ij' and 'u_hat' plus the source's 'h'.\"\"\"\n",
    "    message = dict(b_ij=edge['b_ij'])\n",
    "    message['h'] = src['h']\n",
    "    message['u_hat'] = edge['u_hat']\n",
    "    return message\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "20"
      ]
     },
     "execution_count": 124,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g.number_of_nodes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> \u001b[0;32m<ipython-input-120-f6b14912c0b0>\u001b[0m(4)\u001b[0;36mv2_reduce\u001b[0;34m()\u001b[0m\n",
      "\u001b[0;32m      1 \u001b[0;31m\u001b[0;32mdef\u001b[0m \u001b[0mv2_reduce\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m      2 \u001b[0;31m    \u001b[0mb_ij_c\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'b_ij'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'u_hat'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m      3 \u001b[0;31m    \u001b[0mc_i\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb_ij_c\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m----> 4 \u001b[0;31m    \u001b[0ms_j\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mc_i\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mu_hat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\u001b[0;32m      5 \u001b[0;31m    \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'h'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0ms_j\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0m\n",
      "ipdb> c_i.shape\n",
      "torch.Size([10, 40])\n",
      "ipdb> u_hat.shape\n",
      "torch.Size([10, 40, 2])\n",
      "ipdb> q\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (50) must match the size of tensor b (10) at non-singleton dimension 1",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-123-847b35d9c410>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'debug'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_all\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcapsule_msg\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mv2_reduce\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/.pyenv/versions/3.6.1/lib/python3.6/site-packages/dgl-0.0.1-py3.6.egg/dgl/graph.py\u001b[0m in \u001b[0;36mupdate_all\u001b[0;34m(self, message_func, reduce_func, apply_node_func)\u001b[0m\n\u001b[1;32m   1138\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1139\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mALL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mALL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmessage_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1140\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrecv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mALL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1141\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_nodes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mALL\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_node_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1142\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.pyenv/versions/3.6.1/lib/python3.6/site-packages/dgl-0.0.1-py3.6.egg/dgl/graph.py\u001b[0m in \u001b[0;36mrecv\u001b[0;34m(self, u, reduce_func, apply_node_func)\u001b[0m\n\u001b[1;32m    944\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduce_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    945\u001b[0m             \u001b[0mreduce_func\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBundledReduceFunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduce_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 946\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batch_recv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    947\u001b[0m         \u001b[0;31m# optional apply nodes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    948\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_nodes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_node_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.pyenv/versions/3.6.1/lib/python3.6/site-packages/dgl-0.0.1-py3.6.egg/dgl/graph.py\u001b[0m in \u001b[0;36m_batch_recv\u001b[0;34m(self, v, reduce_func)\u001b[0m\n\u001b[1;32m    990\u001b[0m                         lambda key: _reshape_fn(in_msgs[key]), self._msg_frame.schemes)\n\u001b[1;32m    991\u001b[0m             \u001b[0mreordered_v\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv_bkt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtousertensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 992\u001b[0;31m             \u001b[0mnew_reprs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreduce_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdst_reprs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreshaped_in_msgs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    993\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    994\u001b[0m         \u001b[0;31m# TODO: clear partial messages\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-120-f6b14912c0b0>\u001b[0m in \u001b[0;36mv2_reduce\u001b[0;34m(node, msg)\u001b[0m\n\u001b[1;32m      2\u001b[0m     \u001b[0mb_ij_c\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mu_hat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'b_ij'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'u_hat'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mc_i\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb_ij_c\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m     \u001b[0ms_j\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mc_i\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mu_hat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'h'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0ms_j\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (50) must match the size of tensor b (10) at non-singleton dimension 1"
     ]
    }
   ],
   "source": [
    "# Send capsule_msg along every edge, reduce with v2_reduce, then apply an\n",
    "# identity node update. The leftover %debug magic was dropped: it opens an\n",
    "# interactive post-mortem debugger and stalls a Restart & Run All.\n",
    "g.update_all(capsule_msg,v2_reduce,lambda x:x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.set_e_repr({'u_hat':th.zeros()})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "b_ij="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.set_e_repr()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "nxg=g.to_networkx()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}