Commit b7187dd3 authored by Da Zheng
Browse files

Merge branch 'dist_part' of github.com:dmlc/dgl into dist_part

parents a324440f 3ccd973c
...@@ -9,6 +9,7 @@ import pyarrow ...@@ -9,6 +9,7 @@ import pyarrow
import pandas as pd import pandas as pd
import constants import constants
from pyarrow import csv from pyarrow import csv
from utils import read_json
def create_dgl_object(graph_name, num_parts, \ def create_dgl_object(graph_name, num_parts, \
schema, part_id, node_data, \ schema, part_id, node_data, \
...@@ -129,7 +130,8 @@ def create_dgl_object(graph_name, num_parts, \ ...@@ -129,7 +130,8 @@ def create_dgl_object(graph_name, num_parts, \
assert len(uniq_ids) == len(idx) assert len(uniq_ids) == len(idx)
# We get the edge list with their node IDs mapped to a contiguous ID range. # We get the edge list with their node IDs mapped to a contiguous ID range.
part_local_src_id, part_local_dst_id = np.split(inverse_idx[:len(shuffle_global_src_id) * 2], 2) part_local_src_id, part_local_dst_id = np.split(inverse_idx[:len(shuffle_global_src_id) * 2], 2)
compact_g = dgl.graph((part_local_src_id, part_local_dst_id))
compact_g = dgl.graph(data=(part_local_src_id, part_local_dst_id), num_nodes=len(idx))
compact_g.edata['orig_id'] = th.as_tensor(global_edge_id) compact_g.edata['orig_id'] = th.as_tensor(global_edge_id)
compact_g.edata[dgl.ETYPE] = th.as_tensor(etype_ids) compact_g.edata[dgl.ETYPE] = th.as_tensor(etype_ids)
compact_g.edata['inner_edge'] = th.ones( compact_g.edata['inner_edge'] = th.ones(
...@@ -232,6 +234,6 @@ def create_metadata_json(graph_name, num_nodes, num_edges, num_parts, node_map_v ...@@ -232,6 +234,6 @@ def create_metadata_json(graph_name, num_nodes, num_edges, num_parts, node_map_v
edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl")
part_graph_file = os.path.join(part_dir, "graph.dgl") part_graph_file = os.path.join(part_dir, "graph.dgl")
part_metadata['part-{}'.format(part_id)] = {'node_feats': node_feat_file, part_metadata['part-{}'.format(part_id)] = {'node_feats': node_feat_file,
'edge_feats': edge_feat_file, 'edge_feats': edge_feat_file,
'part_graph': part_graph_file} 'part_graph': part_graph_file}
return part_metadata return part_metadata
...@@ -50,7 +50,7 @@ def get_shuffle_global_nids(rank, world_size, global_nids_ranks, node_data): ...@@ -50,7 +50,7 @@ def get_shuffle_global_nids(rank, world_size, global_nids_ranks, node_data):
#form the outgoing message #form the outgoing message
send_nodes = [] send_nodes = []
for i in range(world_size): for i in range(world_size):
send_nodes.append(torch.Tensor(global_nids_ranks[i]).type(dtype=torch.int64)) send_nodes.append(torch.from_numpy(global_nids_ranks[i]).type(dtype=torch.int64))
#send-recieve messages #send-recieve messages
alltoallv_cpu(rank, world_size, recv_nodes, send_nodes) alltoallv_cpu(rank, world_size, recv_nodes, send_nodes)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment