Commit b7187dd3 authored by Da Zheng's avatar Da Zheng
Browse files

Merge branch 'dist_part' of github.com:dmlc/dgl into dist_part

parents a324440f 3ccd973c
...@@ -9,6 +9,7 @@ import pyarrow ...@@ -9,6 +9,7 @@ import pyarrow
import pandas as pd import pandas as pd
import constants import constants
from pyarrow import csv from pyarrow import csv
from utils import read_json
def create_dgl_object(graph_name, num_parts, \ def create_dgl_object(graph_name, num_parts, \
schema, part_id, node_data, \ schema, part_id, node_data, \
...@@ -129,7 +130,8 @@ def create_dgl_object(graph_name, num_parts, \ ...@@ -129,7 +130,8 @@ def create_dgl_object(graph_name, num_parts, \
assert len(uniq_ids) == len(idx) assert len(uniq_ids) == len(idx)
# We get the edge list with their node IDs mapped to a contiguous ID range. # We get the edge list with their node IDs mapped to a contiguous ID range.
part_local_src_id, part_local_dst_id = np.split(inverse_idx[:len(shuffle_global_src_id) * 2], 2) part_local_src_id, part_local_dst_id = np.split(inverse_idx[:len(shuffle_global_src_id) * 2], 2)
compact_g = dgl.graph((part_local_src_id, part_local_dst_id))
compact_g = dgl.graph(data=(part_local_src_id, part_local_dst_id), num_nodes=len(idx))
compact_g.edata['orig_id'] = th.as_tensor(global_edge_id) compact_g.edata['orig_id'] = th.as_tensor(global_edge_id)
compact_g.edata[dgl.ETYPE] = th.as_tensor(etype_ids) compact_g.edata[dgl.ETYPE] = th.as_tensor(etype_ids)
compact_g.edata['inner_edge'] = th.ones( compact_g.edata['inner_edge'] = th.ones(
......
...@@ -50,7 +50,7 @@ def get_shuffle_global_nids(rank, world_size, global_nids_ranks, node_data): ...@@ -50,7 +50,7 @@ def get_shuffle_global_nids(rank, world_size, global_nids_ranks, node_data):
#form the outgoing message #form the outgoing message
send_nodes = [] send_nodes = []
for i in range(world_size): for i in range(world_size):
send_nodes.append(torch.Tensor(global_nids_ranks[i]).type(dtype=torch.int64)) send_nodes.append(torch.from_numpy(global_nids_ranks[i]).type(dtype=torch.int64))
#send-recieve messages #send-recieve messages
alltoallv_cpu(rank, world_size, recv_nodes, send_nodes) alltoallv_cpu(rank, world_size, recv_nodes, send_nodes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment