"""Simple demo of routing between a layer of input and output nodes.

Builds a complete bipartite DGL graph from ``in_nodes`` input nodes to
``out_nodes`` output nodes and seeds random per-edge tensors ``u_hat`` and
``b``. It then calls ``DGLRoutingLayer`` 15 times. After each call it
records, per output node, the sum over inputs of ``-c * log(c)`` for the
per-edge coefficients ``c`` found in ``g.edata['c']``.
"""
import dgl
import torch as th
import torch.nn as nn
from torch.nn import functional as F

from DGLRoutingLayer import DGLRoutingLayer

g = dgl.DGLGraph()
g.graph_data = {}

in_nodes = 20
out_nodes = 10
g.graph_data['in_nodes'] = in_nodes
g.graph_data['out_nodes'] = out_nodes
all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes)

# Node ids [0, in_nodes) are inputs; [in_nodes, all_nodes) are outputs.
in_indx = list(range(in_nodes))
out_indx = list(range(in_nodes, in_nodes + out_nodes))
g.graph_data['in_indx'] = in_indx
g.graph_data['out_indx'] = out_indx

# Add edges using edge broadcasting (one source, all destinations).
# Edge ids follow insertion order. Adding edges input-major therefore gives
# edge id i * out_nodes + j for input i -> output j. The
# ``view(in_nodes, out_nodes)`` below relies on exactly this layout;
# output-major insertion would scramble rows and columns of that matrix.
for u in in_indx:
    g.add_edges(u, out_indx)

# Initial states (per-edge values are random, so edge order does not bias them).
f_size = 4
g.ndata['v'] = th.zeros(all_nodes, f_size)
g.edata['u_hat'] = th.randn(in_nodes * out_nodes, f_size)
g.edata['b'] = th.randn(in_nodes * out_nodes, 1)

# NOTE(review): DGLRoutingLayer is defined outside this file. This script only
# relies on each call writing one coefficient per edge into g.edata['c'].
routing_layer = DGLRoutingLayer(g)

# Floor for log(): an exactly-zero coefficient then contributes 0 to the sum
# instead of 0 * -inf = nan.
_LOG_EPS = 1e-12

entropy_list = []
for i in range(15):
    routing_layer()
    # Row i, column j: coefficient on the edge input i -> output j.
    dist_matrix = g.edata['c'].view(in_nodes, out_nodes)
    # Sum of -c * log(c) over the input nodes, one value per output node.
    entropy = (-dist_matrix * th.log(dist_matrix.clamp(min=_LOG_EPS))).sum(dim=0)
    # detach() is the supported replacement for the legacy .data accessor.
    entropy_list.append(entropy.detach().numpy())

# Per-output-node spread of the final coefficients over the inputs. It is
# computed once, because only the last iteration's value was ever kept.
std = dist_matrix.std(dim=0)