Commit e17c41c0 authored by VoVAllen's avatar VoVAllen Committed by Minjie Wang
Browse files

[Model][Tutorial] Fix capsule memory leak (#185)

* fix memory leak & Remove unnecessary initializer

* change confusing name

* fix name

* Move func outside loop

* fix name inconsistency
parent 79fe09d3
...@@ -22,7 +22,6 @@ class DGLDigitCapsuleLayer(nn.Module): ...@@ -22,7 +22,6 @@ class DGLDigitCapsuleLayer(nn.Module):
device=self.device) device=self.device)
routing(u_hat, routing_num=3) routing(u_hat, routing_num=3)
out_nodes_feature = routing.g.nodes[routing.out_indx].data['v'] out_nodes_feature = routing.g.nodes[routing.out_indx].data['v']
routing.end()
# shape transformation is for further classification # shape transformation is for further classification
return out_nodes_feature.transpose(0, 1).unsqueeze(1).unsqueeze(4).squeeze(1) return out_nodes_feature.transpose(0, 1).unsqueeze(1).unsqueeze(4).squeeze(1)
......
...@@ -8,7 +8,7 @@ class DGLRoutingLayer(nn.Module): ...@@ -8,7 +8,7 @@ class DGLRoutingLayer(nn.Module):
def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device='cpu'): def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device='cpu'):
super(DGLRoutingLayer, self).__init__() super(DGLRoutingLayer, self).__init__()
self.batch_size = batch_size self.batch_size = batch_size
self.g = init_graph(in_nodes, out_nodes, f_size, device=device, batch_size=batch_size) self.g = init_graph(in_nodes, out_nodes, f_size, device=device)
self.in_nodes = in_nodes self.in_nodes = in_nodes
self.out_nodes = out_nodes self.out_nodes = out_nodes
self.in_indx = list(range(in_nodes)) self.in_indx = list(range(in_nodes))
...@@ -17,22 +17,26 @@ class DGLRoutingLayer(nn.Module): ...@@ -17,22 +17,26 @@ class DGLRoutingLayer(nn.Module):
def forward(self, u_hat, routing_num=1): def forward(self, u_hat, routing_num=1):
self.g.edata['u_hat'] = u_hat self.g.edata['u_hat'] = u_hat
for r in range(routing_num): batch_size = self.batch_size
# step 1 (line 4): normalize over out edges
in_edges = self.g.edata['b'].view(self.in_nodes, self.out_nodes) # step 2 (line 5)
self.g.edata['c'] = F.softmax(in_edges, dim=1).view(-1, 1) def cap_message(edges):
if batch_size:
return {'m': edges.data['c'].unsqueeze(1) * edges.data['u_hat']}
else:
return {'m': edges.data['c'] * edges.data['u_hat']}
self.g.register_message_func(cap_message)
def cap_message(edges): def cap_reduce(nodes):
if self.batch_size: return {'s': th.sum(nodes.mailbox['m'], dim=1)}
return {'m': edges.data['c'].unsqueeze(1) * edges.data['u_hat']}
else:
return {'m': edges.data['c'] * edges.data['u_hat']}
self.g.register_message_func(cap_message)
# step 2 (line 5) self.g.register_reduce_func(cap_reduce)
def cap_reduce(nodes):
return {'s': th.sum(nodes.mailbox['m'], dim=1)} for r in range(routing_num):
self.g.register_reduce_func(cap_reduce) # step 1 (line 4): normalize over out edges
edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)
# Execute step 1 & 2 # Execute step 1 & 2
self.g.update_all() self.g.update_all()
...@@ -50,13 +54,6 @@ class DGLRoutingLayer(nn.Module): ...@@ -50,13 +54,6 @@ class DGLRoutingLayer(nn.Module):
else: else:
self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True) self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)
def end(self):
del self.g
# del self.g.edata['u_hat']
# del self.g.ndata['v']
# del self.g.ndata['s']
# del self.g.edata['b']
def squash(s, dim=1): def squash(s, dim=1):
sq = th.sum(s ** 2, dim=dim, keepdim=True) sq = th.sum(s ** 2, dim=dim, keepdim=True)
...@@ -65,8 +62,9 @@ def squash(s, dim=1): ...@@ -65,8 +62,9 @@ def squash(s, dim=1):
return s return s
def init_graph(in_nodes, out_nodes, f_size, device='cpu', batch_size=0): def init_graph(in_nodes, out_nodes, f_size, device='cpu'):
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.set_n_initializer(dgl.frame.zero_initializer)
all_nodes = in_nodes + out_nodes all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes) g.add_nodes(all_nodes)
in_indx = list(range(in_nodes)) in_indx = list(range(in_nodes))
...@@ -75,10 +73,5 @@ def init_graph(in_nodes, out_nodes, f_size, device='cpu', batch_size=0): ...@@ -75,10 +73,5 @@ def init_graph(in_nodes, out_nodes, f_size, device='cpu', batch_size=0):
for u in in_indx: for u in in_indx:
g.add_edges(u, out_indx) g.add_edges(u, out_indx)
# init states
if batch_size:
g.ndata['v'] = th.zeros(all_nodes, batch_size, f_size).to(device)
else:
g.ndata['v'] = th.zeros(all_nodes, f_size).to(device)
g.edata['b'] = th.zeros(in_nodes * out_nodes, 1).to(device) g.edata['b'] = th.zeros(in_nodes * out_nodes, 1).to(device)
return g return g
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment