"git@developer.sourcefind.cn:OpenDAS/deepspeed.git" did not exist on "b1ddea7fd94c52a7be76cec721d32d438681af83"
Unverified Commit b05cb84a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[BUG] Fixes #1262 (#1286)

parent ddf8c858
...@@ -63,9 +63,9 @@ def main(args): ...@@ -63,9 +63,9 @@ def main(args):
feats = torch.arange(num_nodes) feats = torch.arange(num_nodes)
# edge type and normalization factor # edge type and normalization factor
edge_type = torch.from_numpy(data.edge_type) edge_type = torch.from_numpy(data.edge_type).long()
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1) edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1).long()
labels = torch.from_numpy(labels).view(-1) labels = torch.from_numpy(labels).view(-1).long()
# check cuda # check cuda
use_cuda = args.gpu >= 0 and torch.cuda.is_available() use_cuda = args.gpu >= 0 and torch.cuda.is_available()
......
...@@ -141,7 +141,7 @@ def build_graph_from_triplets(num_nodes, num_rels, triplets): ...@@ -141,7 +141,7 @@ def build_graph_from_triplets(num_nodes, num_rels, triplets):
g.add_edges(src, dst) g.add_edges(src, dst)
norm = comp_deg_norm(g) norm = comp_deg_norm(g)
print("# nodes: {}, # edges: {}".format(num_nodes, len(src))) print("# nodes: {}, # edges: {}".format(num_nodes, len(src)))
return g, rel, norm return g, rel.astype('int64'), norm.astype('int64')
def build_test_graph(num_nodes, num_rels, edges): def build_test_graph(num_nodes, num_rels, edges):
src, rel, dst = edges.transpose() src, rel, dst = edges.transpose()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment