Commit 57f480f5 authored by Ivan Brugere, committed by Minjie Wang

LSTM fix (#28)

Correctly handle LSTM model creation: build the nn.LSTM once in NodeUpdateModule.__init__ (when lstm_size is given) and reuse it in forward, instead of instantiating a fresh, untrained LSTM on every forward call.
parent c50b90cf
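
What changed, in short: the nn.LSTM is now constructed once in NodeUpdateModule.__init__ and reused on every forward pass, so its weights are registered as module parameters and actually trained; previously a fresh LSTM was instantiated inside forward() for each call. A minimal sketch of that pattern, simplified from the module in this diff (the class name UpdateWithLSTM and the plain-tensor interface are illustrative, not the repository's API):

import torch
import torch.nn as nn

class UpdateWithLSTM(nn.Module):
    # Sketch: build the LSTM once in __init__, reuse it in forward().
    def __init__(self, lstm_size=0):
        super(UpdateWithLSTM, self).__init__()
        # lstm_size == 0 means "no recurrent update", mirroring the diff's default
        self.lstm = (nn.LSTM(input_size=lstm_size, hidden_size=lstm_size, num_layers=1)
                     if lstm_size else None)

    def forward(self, h, h0=None, c0=None):
        if self.lstm is None:
            return h, None, None
        # missing recurrent state defaults to zeros of the same shape as h
        h0 = torch.zeros_like(h) if h0 is None else h0
        c0 = torch.zeros_like(h) if c0 is None else c0
        # add a length-1 sequence dimension, run the LSTM, then squeeze it back out
        out, (h_i, c) = self.lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))
        return out.squeeze(0), h_i.squeeze(0), c.squeeze(0)

For example, UpdateWithLSTM(lstm_size=64) applied to a (num_nodes, 64) tensor returns an updated (num_nodes, 64) tensor plus the LSTM state to feed back in on the next layer.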
@@ -32,6 +32,7 @@ class NodeReduceModule(nn.Module):
         self.fc = nn.ModuleList(
             [nn.Linear(input_dim, num_hidden, bias=False)
              for _ in range(num_heads)])
         self.attention = nn.ModuleList(
             [nn.Linear(num_hidden * 2, 1, bias=False) for _ in range(num_heads)])
@@ -61,14 +62,20 @@ class NodeReduceModule(nn.Module):
 class NodeUpdateModule(nn.Module):
-    def __init__(self, residual, fc, act, aggregator):
+    def __init__(self, residual, fc, act, aggregator, lstm_size=0):
         super(NodeUpdateModule, self).__init__()
         self.residual = residual
         self.fc = fc
         self.act = act
         self.aggregator = aggregator
+        if lstm_size:
+            self.lstm = nn.LSTM(input_size=lstm_size, hidden_size=lstm_size, num_layers=1)
+        else:
+            self.lstm = None
+        #print(fc[0].out_features)
 
     def forward(self, node, msgs_repr):
         # apply residual connection and activation for each head
         for i in range(len(msgs_repr)):
@@ -80,26 +87,28 @@ class NodeUpdateModule(nn.Module):
         # aggregate multi-head results
         h = self.aggregator(msgs_repr)
-        c0 = torch.zeros(h.shape)
-        if node['c'] is None:
-            c0 = torch.zeros(h.shape)
-        else:
-            c0 = node['c']
-        if node['h_i'] is None:
-            h0 = torch.zeros(h.shape)
-        else:
-            h0 = node['h_i']
-        lstm = nn.LSTM(input_size=h.shape[1], hidden_size=h.shape[1], num_layers=1)
-        #add dimension to handle sequential (create sequence of length 1)
-        h, (h_i, c) = lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))
-        #remove sequential dim
-        h = torch.squeeze(h, 0)
-        h_i = torch.squeeze(h, 0)
-        c = torch.squeeze(c, 0)
-        return {'h': h, 'c':c, 'h_i':h_i}
+        #print(h.shape)
+        if self.lstm is not None:
+            c0 = torch.zeros(h.shape)
+            if node['c'] is None:
+                c0 = torch.zeros(h.shape)
+            else:
+                c0 = node['c']
+            if node['h_i'] is None:
+                h0 = torch.zeros(h.shape)
+            else:
+                h0 = node['h_i']
+            #add dimension to handle sequential (create sequence of length 1)
+            h, (h_i, c) = self.lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))
+            #remove sequential dim
+            h = torch.squeeze(h, 0)
+            h_i = torch.squeeze(h, 0)
+            c = torch.squeeze(c, 0)
+            return {'h': h, 'c':c, 'h_i':h_i}
+        else:
+            return {'h': h, 'c':None, 'h_i':None}
 
 class GeniePath(nn.Module):
     def __init__(self, num_layers, in_dim, num_hidden, num_classes, num_heads,
@@ -122,7 +131,7 @@ class GeniePath(nn.Module):
                     attention_dropout))
             self.update_layers.append(
                 NodeUpdateModule(residual, self.reduce_layers[-1].fc, activation,
-                                 lambda x: torch.cat(x, 1)))
+                                 lambda x: torch.cat(x, 1), num_hidden * num_heads))
         # projection
         self.reduce_layers.append(
             NodeReduceModule(num_hidden * num_heads, num_classes, 1, input_dropout,
...
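
A usage note on the new return values: forward() now hands the LSTM state back under the 'c' and 'h_i' keys and expects it in node['c'] / node['h_i'] on the next layer, with None meaning "start from zeros". A rough sketch of that round trip across stacked update steps, with the DGL graph plumbing omitted and the sizes (and the single shared LSTM) purely illustrative:

import torch
import torch.nn as nn

num_hidden, num_heads, num_nodes = 8, 4, 10
lstm_size = num_hidden * num_heads  # matches the lstm_size passed in the diff

# In the real model each NodeUpdateModule owns its own LSTM; one is reused here for brevity.
lstm = nn.LSTM(input_size=lstm_size, hidden_size=lstm_size, num_layers=1)

# No recurrent state yet, so the None checks in forward() fall back to zeros.
node = {'h': torch.randn(num_nodes, lstm_size), 'c': None, 'h_i': None}

for _ in range(3):  # three stacked update steps
    h = node['h']
    c0 = torch.zeros(h.shape) if node['c'] is None else node['c']
    h0 = torch.zeros(h.shape) if node['h_i'] is None else node['h_i']
    out, (h_i, c) = lstm(h.unsqueeze(0), (h0.unsqueeze(0), c0.unsqueeze(0)))
    node = {'h': out.squeeze(0), 'c': c.squeeze(0), 'h_i': h_i.squeeze(0)}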