Unverified Commit b05cb84a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Fixes #1262 (#1286)

parent ddf8c858
......@@ -63,9 +63,9 @@ def main(args):
feats = torch.arange(num_nodes)
# edge type and normalization factor
edge_type = torch.from_numpy(data.edge_type)
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1)
labels = torch.from_numpy(labels).view(-1)
edge_type = torch.from_numpy(data.edge_type).long()
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1).long()
labels = torch.from_numpy(labels).view(-1).long()
# check cuda
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
......
......@@ -141,7 +141,7 @@ def build_graph_from_triplets(num_nodes, num_rels, triplets):
g.add_edges(src, dst)
norm = comp_deg_norm(g)
print("# nodes: {}, # edges: {}".format(num_nodes, len(src)))
return g, rel, norm
return g, rel.astype('int64'), norm.astype('int64')
def build_test_graph(num_nodes, num_rels, edges):
src, rel, dst = edges.transpose()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment