{
"cells": [
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: DGL_LIBRARY_PATH=/data/jinjing/dgl/build\n",
"Populating the interactive namespace from numpy and matplotlib\n"
]
}
],
"source": [
"%env DGL_LIBRARY_PATH=/data/jinjing/dgl/build\n",
"%pylab inline\n",
"%config InlineBackend.figure_format = 'svg'\n",
"\n",
"import torch as th\n",
"import numpy as np\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import dgl\n",
"import networkx as nx\n",
"\n",
"\n",
"def update_edge(u, v, edge):\n",
"    \"\"\"Edge UDF: add the agreement <v['h'], u_hat> to the routing logit b_ij.\n",
"\n",
"    `v` is presumably the destination-node data (DGL's (src, dst, edge) UDF\n",
"    order); `u` is unused. The dot product is taken over dim 1.\n",
"    \"\"\"\n",
"    agreement = th.sum(v['h'] * edge['u_hat'], dim=1)\n",
"    return {'b_ij': edge['b_ij'] + agreement}\n",
"\n",
"\n",
"def v2_reduce(node, msg):\n",
"    \"\"\"Reduce UDF: weight incoming predictions u_hat by softmax(b_ij) and sum them.\n",
"\n",
"    Returns {'s_j': ...}, the sum over dim 1 (the mailbox axis) of c * u_hat.\n",
"    NOTE(review): the softmax runs over dim 0, i.e. across the destination\n",
"    nodes batched into this call, not across a node's own incoming edges.\n",
"    Assumes mailbox shapes (num_dst, in_degree) for b_ij and\n",
"    (num_dst, in_degree, dim) for u_hat -- confirm against the DGL version.\n",
"    \"\"\"\n",
"    logits = msg['b_ij']\n",
"    predictions = msg['u_hat']\n",
"    coupling = F.softmax(logits, dim=0)\n",
"    weighted = coupling.unsqueeze(2) * predictions\n",
"    return {'s_j': weighted.sum(dim=1)}\n",
"\n",
"\n",
"def capsule_msg(src, edge):\n",
"    \"\"\"Message UDF: forward the edge's routing logit and prediction, plus the source 'h'.\"\"\"\n",
"    # NOTE(review): 'h' is sent along, but v2_reduce only reads 'b_ij' and 'u_hat'.\n",
"    return dict(b_ij=edge['b_ij'], h=src['h'], u_hat=edge['u_hat'])\n",
"\n",
"\n",
"def squash(s):\n",
"    \"\"\"Apply-node UDF: capsule squash of the node field 's_j', stored as 'h'.\n",
"\n",
"    v = |s|^2 / (1 + |s|^2) * s / |s|, with the norm taken over dim 1. It is\n",
"    evaluated as s * |s| / (1 + |s|^2), which is algebraically identical but\n",
"    returns 0 instead of NaN (0/0) when a row of s_j is all zeros.\n",
"    \"\"\"\n",
"    s_j = s['s_j']\n",
"    norm_sq = th.sum(s_j ** 2, dim=1, keepdim=True)\n",
"    norm = th.sqrt(norm_sq)\n",
"    return {'h': s_j * norm / (1.0 + norm_sq)}\n",
"\n",
"\n",
"# Seed both RNGs: numpy drives the input points below, torch drives the\n",
"# random transformation matrices W (th.randn) created in a later cell.\n",
"np.random.seed(10)\n",
"th.manual_seed(10)\n",
"num_points = 10   # number of source nodes (input points)\n",
"num_points2 = 2   # number of destination (output) nodes\n",
"\n",
"# Points on an annulus: angle uniform in [0, 2*pi), radius uniform in [3, 5).\n",
"theta_dist1 = np.pi * (np.random.rand(num_points) * 2)\n",
"\n",
"radius_1 = np.random.rand(num_points) * 2 + 3\n",
"# points_1 has shape (2, num_points): row 0 holds x coordinates, row 1 holds y.\n",
"points_1 = th.from_numpy(\n",
"    np.array([np.cos(theta_dist1) * radius_1,\n",
"              np.sin(theta_dist1) * radius_1])).float()"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Bipartite graph: nodes 0..num_points-1 are the input points, nodes\n",
"# num_points..num_points+num_points2-1 are the outputs; every input connects to\n",
"# every output. Edges are added input-major (i outer, j inner), so the k-th added\n",
"# edge, k = i * num_points2 + j, joins input i to output j. Per-edge tensors\n",
"# (b_ij, u_hat) must follow this order.\n",
"g = dgl.DGLGraph()\n",
"g.add_nodes(num_points + num_points2)\n",
"\n",
"for i in range(num_points):\n",
"    for j in range(num_points2):\n",
"        g.add_edge(i, j + num_points)\n",
"\n",
"nodes1 = list(range(num_points))\n",
"nodes2 = list(range(num_points, num_points + num_points2))"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# W[j, i] is the 2x2 transform from input point i to output node j.\n",
"W = th.randn((num_points2, num_points, 2, 2))\n",
"# (num_points, 1, 2) @ (num_points2, num_points, 2, 2) broadcasts to\n",
"# (num_points2, num_points, 1, 2), where entry [j, i] = points_1[:, i] @ W[j, i].\n",
"# Transpose to (num_points, num_points2, 2) before flattening so that row\n",
"# i * num_points2 + j matches the edge i -> j (edges were added input-major).\n",
"# reshape (not view) because the transposed tensor is non-contiguous.\n",
"u_hat = (th.matmul(points_1.t().unsqueeze(1), W)\n",
"         .squeeze(2)\n",
"         .transpose(0, 1)\n",
"         .reshape(-1, 2))"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([10, 1, 2])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: shape of the batched input fed to th.matmul when building u_hat.\n",
"points_1.t()[:, None, :].shape"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([20, 2])"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One 2-d prediction per edge is expected here.\n",
"u_hat.size()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Node field 'h': input nodes hold their (x, y) point, output nodes start at zero.\n",
"g.set_n_repr({'h': th.cat([points_1.t(), th.zeros((num_points2, 2))], dim=0)})\n",
"# Edge fields: routing logits b_ij start at zero; u_hat holds the per-edge predictions.\n",
"g.set_e_repr({'b_ij': th.zeros(g.number_of_edges(), dtype=th.float32)})\n",
"g.set_e_repr({'u_hat': u_hat})"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jinjing/.pyenv/versions/3.6.1/lib/python3.6/site-packages/dgl-0.0.1-py3.6.egg/dgl/frame.py:256: UserWarning: Initializer is not set. Use zero initializer instead. To suppress this warning, use `set_initializer` to explicitly specify which initializer to use.\n",
" dgl_warning('Initializer is not set. Use zero initializer instead.'\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Dynamic routing. NOTE: this mutates g in place and starts from whatever state\n",
"# g currently holds, so re-running this cell continues the previous routing;\n",
"# re-run the set_n_repr/set_e_repr cell first to start over.\n",
"iterations = 10\n",
"snapshot_iters = (1, 3, 7, 9)  # 0-based loop indices whose output nodes are kept and plotted\n",
"selected_node_feats = [points_1.detach().numpy().T]\n",
"for i in range(iterations):\n",
"    # messages -> softmax-weighted sum s_j -> squash into 'h'\n",
"    g.update_all(capsule_msg, v2_reduce, squash)\n",
"    # b_ij += <h, u_hat> on every edge\n",
"    g.update_edge(edge_func=update_edge)\n",
"\n",
"    output_h = g.get_n_repr()['h'][num_points:num_points + num_points2]\n",
"    output_feats = output_h.detach().numpy()\n",
"    if i in snapshot_iters:\n",
"        selected_node_feats.append(output_feats)\n",
"        plt.scatter(*output_feats.T, color='blue', label='output nodes')\n",
"        plt.scatter(*points_1, color='red', label='input points')\n",
"        # i is 0-based, so this snapshot is taken after i + 1 routing passes.\n",
"        plt.title(f\"Routing iteration: {i + 1}/{iterations}\")\n",
"        plt.xlabel('x')\n",
"        plt.ylabel('y')\n",
"        plt.legend()\n",
"        plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Replay the saved snapshots: blue = snapshot rows, red = input points,\n",
"# aqua = first row of the snapshot drawn on top.\n",
"for snapshot in selected_node_feats:\n",
"    xs, ys = snapshot.T\n",
"    plt.scatter(xs, ys, color='blue')\n",
"    plt.scatter(points_1[0], points_1[1], color='red')\n",
"    first_x, first_y = snapshot[0]\n",
"    plt.scatter(first_x, first_y, color='aqua')\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}