{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Toy routing-by-agreement on a bipartite DGL graph\n",
    "\n",
    "Scratch experiment: two random 2-D point clouds (an inner cloud with radius in [0, 1) and an outer ring with radius in [3, 5)) are joined by a complete bipartite graph, and one round of capsule-style message passing (`capsule_msg` / `v2_reduce`) is run over it.\n",
    "\n",
    "Written against the early `dgl` 0.0.1 API (`set_n_repr`, `set_e_repr`, `update_edge`, `update_all(message, reduce, apply)`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Location of the locally built DGL library; set before importing dgl.\n",
    "# An existing DGL_LIBRARY_PATH in the environment takes precedence.\n",
    "os.environ.setdefault('DGL_LIBRARY_PATH', '/data/jinjing/dgl/build');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "num_points = 10   # size of the inner cloud (nodes 0 .. num_points-1)\n",
    "num_points2 = 10  # size of the outer ring (the remaining nodes)\n",
    "\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample the point clouds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Angles are uniform on [0, 2*pi).\n",
    "theta_dist1 = np.pi * (np.random.rand(num_points) * 2)\n",
    "theta_dist2 = np.pi * (np.random.rand(num_points2) * 2)\n",
    "\n",
    "# Radii: [0, 1) for the inner cloud, [3, 5) for the outer ring.\n",
    "radius_1 = np.random.rand(num_points)\n",
    "radius_2 = np.random.rand(num_points2) * 2 + 3\n",
    "\n",
    "# Each array has shape (2, n): row 0 holds x, row 1 holds y.\n",
    "points_1 = np.array([np.cos(theta_dist1) * radius_1, np.sin(theta_dist1) * radius_1])\n",
    "points_2 = np.array([np.cos(theta_dist2) * radius_2, np.sin(theta_dist2) * radius_2])\n",
    "point_3 = np.array([1, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "ax.scatter(*points_1, color='red', label='points_1')\n",
    "ax.scatter(*points_2, color='blue', label='points_2')\n",
    "ax.scatter(*point_3, color='aqua', label='point_3')\n",
    "ax.set(title='Sampled points', xlabel='x', ylabel='y')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the graph\n",
    "\n",
    "Nodes `0 .. num_points-1` hold the inner cloud and the remaining nodes hold the outer ring; every inner node gets an edge to every outer node. Construction lives in one cell so that re-running it starts from a fresh graph instead of adding duplicate edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = dgl.DGLGraph()\n",
    "g.add_nodes(num_points + num_points2)\n",
    "\n",
    "# Complete bipartite wiring: num_points * num_points2 edges, inner cloud -> outer ring.\n",
    "for i in range(num_points):\n",
    "    for j in range(num_points2):\n",
    "        g.add_edge(i, j + num_points)\n",
    "\n",
    "nodes1 = list(range(num_points))\n",
    "nodes2 = list(range(num_points, num_points + num_points2))\n",
    "g.number_of_nodes()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Node and edge features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node feature 'h': (x, y) coordinates, shape (num_points + num_points2, 2).\n",
    "# Cast to float32 so it matches the dtype of the 'b_ij' logits below.\n",
    "coords = np.concatenate([points_1, points_2], axis=1).T\n",
    "g.set_n_repr({'h': th.from_numpy(coords).float()})\n",
    "\n",
    "# One zero-initialised routing logit per edge.\n",
    "g.set_e_repr({'b_ij': th.zeros(num_points * num_points2)})\n",
    "\n",
    "# 'u_hat' on each edge is a copy of its source node's 'h'.\n",
    "g.update_edge(edge_func=lambda u, v, edge: {'u_hat': u['h']})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Message-passing functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_edge(u, v, edge):\n",
    "    \"\"\"Agreement update for the routing logits.\n",
    "\n",
    "    Adds the dot product of the destination feature v['h'] and the edge\n",
    "    prediction edge['u_hat'] to edge['b_ij']. Assumes both are\n",
    "    (num_edges, dim) tensors, as for the 2-D node features in this notebook.\n",
    "    Takes no `self`: DGL calls an edge function as f(u, v, edge).\n",
    "    \"\"\"\n",
    "    return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).sum(dim=-1)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def capsule_msg(src, edge):\n",
    "    \"\"\"Message per edge: the routing logit, the source feature and the edge prediction.\"\"\"\n",
    "    return {'b_ij': edge['b_ij'], 'h': src['h'], 'u_hat': edge['u_hat']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def v2_reduce(node, msg):\n",
    "    \"\"\"Sum incoming 'u_hat' predictions weighted by softmaxed routing logits.\n",
    "\n",
    "    Observed shapes: msg['b_ij'] is (nodes_in_bucket, in_degree) and\n",
    "    msg['u_hat'] is (nodes_in_bucket, in_degree, dim). Returns the new node\n",
    "    feature 'h' with shape (nodes_in_bucket, dim).\n",
    "    \"\"\"\n",
    "    b_ij_c, u_hat = msg['b_ij'], msg['u_hat']\n",
    "    # NOTE(review): dim=0 normalises across the destination nodes of the bucket,\n",
    "    # not across a node's incoming edges -- confirm this is the intended routing.\n",
    "    c_i = F.softmax(b_ij_c, dim=0)\n",
    "    # A single trailing axis lets the (nodes, in_degree) weights broadcast over dim;\n",
    "    # unsqueeze(2).unsqueeze(3) produced a 4-D tensor that failed to broadcast.\n",
    "    s_j = (c_i.unsqueeze(-1) * u_hat).sum(dim=1)\n",
    "    return {'h': s_j}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## One round of message passing\n",
    "\n",
    "**TODO:** `update_edge` (the agreement update of `b_ij`) is defined but not yet called; wire it into an iterative routing loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g.update_all(capsule_msg, v2_reduce, lambda x: x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nxg = g.to_networkx()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}