Unverified commit d3c24cc2, authored by Lingfan Yu, committed by GitHub
Browse files

[BugFix] Fix bug in RGCN data processing and use index_select to improve speed (#429)

* use index_select instead of __getitem__

* fix bug in dataset processing

* fix edge_type shape bug

* comments
parent fb4246e5
...@@ -89,10 +89,10 @@ class RGCNBasisLayer(RGCNLayer): ...@@ -89,10 +89,10 @@ class RGCNBasisLayer(RGCNLayer):
# an embedding lookup using source node id # an embedding lookup using source node id
embed = weight.view(-1, self.out_feat) embed = weight.view(-1, self.out_feat)
index = edges.data['type'] * self.in_feat + edges.src['id'] index = edges.data['type'] * self.in_feat + edges.src['id']
return {'msg': embed[index] * edges.data['norm']} return {'msg': embed.index_select(0, index) * edges.data['norm']}
else: else:
def msg_func(edges): def msg_func(edges):
w = weight[edges.data['type']] w = weight.index_select(0, edges.data['type'])
msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze() msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
msg = msg * edges.data['norm'] msg = msg * edges.data['norm']
return {'msg': msg} return {'msg': msg}
...@@ -119,7 +119,7 @@ class RGCNBlockLayer(RGCNLayer): ...@@ -119,7 +119,7 @@ class RGCNBlockLayer(RGCNLayer):
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu')) nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
def msg_func(self, edges): def msg_func(self, edges):
weight = self.weight[edges.data['type']].view( weight = self.weight.index_select(0, edges.data['type']).view(
-1, self.submat_in, self.submat_out) -1, self.submat_in, self.submat_out)
node = edges.src['h'].view(-1, 1, self.submat_in) node = edges.src['h'].view(-1, 1, self.submat_in)
msg = torch.bmm(node, weight).view(-1, self.out_feat) msg = torch.bmm(node, weight).view(-1, self.out_feat)
......
...@@ -149,7 +149,7 @@ def main(args): ...@@ -149,7 +149,7 @@ def main(args):
# set node/edge feature # set node/edge feature
node_id = torch.from_numpy(node_id).view(-1, 1) node_id = torch.from_numpy(node_id).view(-1, 1)
edge_type = torch.from_numpy(edge_type).view(-1, 1) edge_type = torch.from_numpy(edge_type)
node_norm = torch.from_numpy(node_norm).view(-1, 1) node_norm = torch.from_numpy(node_norm).view(-1, 1)
data, labels = torch.from_numpy(data), torch.from_numpy(labels) data, labels = torch.from_numpy(data), torch.from_numpy(labels)
deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1) deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1)
......
...@@ -410,8 +410,10 @@ def _load_data(dataset_str='aifb', dataset_path=None): ...@@ -410,8 +410,10 @@ def _load_data(dataset_str='aifb', dataset_path=None):
dst = nodes_dict[o] dst = nodes_dict[o]
assert src < num_node and dst < num_node assert src < num_node and dst < num_node
rel = relations_dict[p] rel = relations_dict[p]
edge_list.append((src, dst, 2 * rel)) # relation id 0 is self-relation, so others should start with 1
edge_list.append((dst, src, 2 * rel + 1)) edge_list.append((src, dst, 2 * rel + 1))
# reverse relation
edge_list.append((dst, src, 2 * rel + 2))
# sort indices by destination # sort indices by destination
edge_list = sorted(edge_list, key=lambda x: (x[1], x[0], x[2])) edge_list = sorted(edge_list, key=lambda x: (x[1], x[0], x[2]))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment