"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "ad07d254fb5cefadf8dcb6020b24fb0baee4e936"
Unverified Commit 53835bdb authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist][BugFix] enable sampling on bipartite (#4014)

* [Dist][BugFix] enable sampling on bipartite

* add comments for tests
parent 3dd54d5c
...@@ -367,6 +367,12 @@ class GraphPartitionBook(ABC): ...@@ -367,6 +367,12 @@ class GraphPartitionBook(ABC):
"""Get the list of edge types """Get the list of edge types
""" """
@property
def is_homogeneous(self):
    """Whether the partitioned graph is homogeneous.

    Returns
    -------
    bool
        True when the graph carries at most one node type and at most
        one edge type; False for heterogeneous (e.g. bipartite) graphs.
    """
    # Equivalent to not(len(self.etypes) > 1 or len(self.ntypes) > 1)
    # by De Morgan's law, stated positively for readability.
    return len(self.etypes) <= 1 and len(self.ntypes) <= 1
def map_to_per_ntype(self, ids): def map_to_per_ntype(self, ids):
"""Map homogeneous node IDs to type-wise IDs and node types. """Map homogeneous node IDs to type-wise IDs and node types.
......
...@@ -504,7 +504,7 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -504,7 +504,7 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
return _sample_etype_neighbors(local_g, partition_book, local_nids, return _sample_etype_neighbors(local_g, partition_book, local_nids,
etype_field, fanout, edge_dir, prob, replace) etype_field, fanout, edge_dir, prob, replace)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1: if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else: else:
return frontier return frontier
...@@ -559,7 +559,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -559,7 +559,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
A sampled subgraph containing only the sampled neighboring edges. It is on CPU. A sampled subgraph containing only the sampled neighboring edges. It is on CPU.
""" """
gpb = g.get_partition_book() gpb = g.get_partition_book()
if len(gpb.etypes) > 1: if not gpb.is_homogeneous:
assert isinstance(nodes, dict) assert isinstance(nodes, dict)
homo_nids = [] homo_nids = []
for ntype in nodes: for ntype in nodes:
...@@ -581,7 +581,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -581,7 +581,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
return _sample_neighbors(local_g, partition_book, local_nids, return _sample_neighbors(local_g, partition_book, local_nids,
fanout, edge_dir, prob, replace) fanout, edge_dir, prob, replace)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1: if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else: else:
return frontier return frontier
......
...@@ -553,6 +553,268 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server): ...@@ -553,6 +553,268 @@ def check_rpc_hetero_etype_sampling_empty_shuffle(tmpdir, num_server):
assert block.number_of_edges() == 0 assert block.number_of_edges() == 0
assert len(block.etypes) == len(g.etypes) assert len(block.etypes) == len(g.etypes)
def create_random_bipartite():
    """Build a random 'user'--buys--'game' bipartite graph for testing.

    The graph has 500 'user' nodes, 1000 'game' nodes and 1000 'buys'
    edges; both node types get an all-ones float32 feature of width 10
    on CPU so feature access can be asserted by the clients.
    """
    graph = dgl.rand_bipartite('user', 'buys', 'game', 500, 1000, 1000)
    for ntype in ('user', 'game'):
        graph.nodes[ntype].data['feat'] = F.ones(
            (graph.num_nodes(ntype), 10), F.float32, F.cpu())
    return graph
def start_bipartite_sample_client(rank, tmpdir, disable_shared_mem, nodes):
    """Client process: run sample_neighbors() on the bipartite DistGraph.

    Parameters
    ----------
    rank : int
        Partition to load when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory holding the partition config 'test_sampling.json'.
    disable_shared_mem : bool
        If True, read the partition book from disk instead of shared memory.
    nodes : dict
        Seed nodes keyed by node type.

    Returns
    -------
    (DGLBlock, GraphPartitionBook)
    """
    gpb = None
    if disable_shared_mem:
        # Shared memory is off: the partition book must be loaded explicitly.
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    # Node features set up by create_random_bipartite() must be visible.
    assert 'feat' in dist_graph.nodes['user'].data
    assert 'feat' in dist_graph.nodes['game'].data
    if gpb is None:
        gpb = dist_graph.get_partition_book()
    frontier = sample_neighbors(dist_graph, nodes, 3)
    block = dgl.to_block(frontier, nodes)
    if frontier.num_edges() > 0:
        # Propagate global edge IDs so the caller can map back to the
        # original graph.
        block.edata[dgl.EID] = frontier.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb
def start_bipartite_etype_sample_client(rank, tmpdir, disable_shared_mem, fanout=3,
                                        nodes=None):
    """Client process: run sample_etype_neighbors() on the bipartite DistGraph.

    Parameters
    ----------
    rank : int
        Partition to load when shared memory is disabled.
    tmpdir : pathlib.Path
        Directory holding the partition config 'test_sampling.json'.
    disable_shared_mem : bool
        If True, read the partition book from disk instead of shared memory.
    fanout : int
        Number of neighbors to sample per edge type.
    nodes : dict, optional
        Seed nodes keyed by node type; defaults to an empty dict.

    Returns
    -------
    (DGLBlock, GraphPartitionBook)
    """
    # BUGFIX: was `nodes={}` — a mutable default argument is shared across
    # calls; use a None sentinel instead (behavior unchanged for callers).
    if nodes is None:
        nodes = {}
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', rank)
    dgl.distributed.initialize("rpc_ip_config.txt")
    dist_graph = DistGraph("test_sampling", gpb=gpb)
    assert 'feat' in dist_graph.nodes['user'].data
    assert 'feat' in dist_graph.nodes['game'].data
    if dist_graph.local_partition is not None:
        # Check whether etypes are sorted in dist_graph: for every node's
        # incoming edges, the edge-type IDs must appear in non-decreasing
        # order (first-occurrence indices of the unique etypes ascend).
        local_g = dist_graph.local_partition
        local_nids = np.arange(local_g.num_nodes())
        for lnid in local_nids:
            leids = local_g.in_edges(lnid, form='eid')
            letids = F.asnumpy(local_g.edata[dgl.ETYPE][leids])
            _, idices = np.unique(letids, return_index=True)
            assert np.all(idices[:-1] <= idices[1:])
    if gpb is None:
        gpb = dist_graph.get_partition_book()
    sampled_graph = sample_etype_neighbors(
        dist_graph, nodes, dgl.ETYPE, fanout)
    block = dgl.to_block(sampled_graph, nodes)
    if sampled_graph.num_edges() > 0:
        # Propagate global edge IDs for the caller's ID-mapping checks.
        block.edata[dgl.EID] = sampled_graph.edata[dgl.EID]
    dgl.distributed.exit_client()
    return block, gpb
def check_rpc_bipartite_sampling_empty(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_bipartite()
    # One partition per server, one-hop halo.
    orig_nids, _ = partition_graph(g, 'test_sampling', num_server, tmpdir,
                                   num_hops=1, part_method='metis',
                                   reshuffle=True, return_mapping=True)
    servers = []
    ctx = mp.get_context('spawn')
    for server_id in range(num_server):
        proc = ctx.Process(target=start_server, args=(
            server_id, tmpdir, num_server > 1, 'test_sampling'))
        proc.start()
        time.sleep(1)
        servers.append(proc)
    # Seed only zero-in-degree 'game' nodes so the sample must come back empty.
    deg = get_degrees(g, orig_nids['game'], 'game')
    empty_nids = F.nonzero_1d(deg == 0)
    block, _ = start_bipartite_sample_client(
        0, tmpdir, num_server > 1, nodes={'game': empty_nids, 'user': [1]})
    print("Done sampling")
    for proc in servers:
        proc.join()
    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)
def check_rpc_bipartite_sampling_shuffle(tmpdir, num_server):
    """sample on bipartite via sample_neighbors() which yields non-empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1
    # reshuffle=True renumbers nodes/edges; orig_nids maps shuffled -> original.
    orig_nids, _ = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                   num_hops=num_hops, part_method='metis', reshuffle=True, return_mapping=True)
    pserver_list = []
    ctx = mp.get_context('spawn')
    # Launch one server process per partition; sleep to let each come up
    # before starting the next.
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)
    # Seed with 'game' nodes that have at least one in-edge, so the
    # sampled frontier is guaranteed non-empty.
    deg = get_degrees(g, orig_nids['game'], 'game')
    nids = F.nonzero_1d(deg > 0)
    block, gpb = start_bipartite_sample_client(0, tmpdir, num_server > 1,
                                               nodes={'game': nids, 'user': [0]})
    print("Done sampling")
    for p in pserver_list:
        p.join()
    # Rebuild per-type shuffled-ID -> original-ID lookup tables by scanning
    # every partition's 'orig_id' arrays.
    orig_nid_map = {ntype: F.zeros(
        (g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
    orig_eid_map = {etype: F.zeros(
        (g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', i)
        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
        for ntype_id, ntype in enumerate(g.ntypes):
            idx = ntype_ids == ntype_id
            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
                                  F.boolean_mask(part.ndata['orig_id'], idx))
        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
        for etype_id, etype in enumerate(g.etypes):
            idx = etype_ids == etype_id
            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
                                  F.boolean_mask(part.edata['orig_id'], idx))
    # Verify every sampled edge: mapping its endpoints and edge ID back to
    # the original graph must agree with g.find_edges on the original ID.
    for src_type, etype, dst_type in block.canonical_etypes:
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(
            block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(
            block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_src = F.asnumpy(F.gather_row(
            orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(
            orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))
        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
def check_rpc_bipartite_etype_sampling_empty(tmpdir, num_server):
    """sample on bipartite via sample_etype_neighbors() which yields empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_bipartite()
    # One partition per server, one-hop halo.
    orig_nids, _ = partition_graph(g, 'test_sampling', num_server, tmpdir,
                                   num_hops=1, part_method='metis',
                                   reshuffle=True, return_mapping=True)
    servers = []
    ctx = mp.get_context('spawn')
    for server_id in range(num_server):
        proc = ctx.Process(target=start_server, args=(
            server_id, tmpdir, num_server > 1, 'test_sampling'))
        proc.start()
        time.sleep(1)
        servers.append(proc)
    # Seed only zero-in-degree 'game' nodes so the sample must come back empty.
    deg = get_degrees(g, orig_nids['game'], 'game')
    empty_nids = F.nonzero_1d(deg == 0)
    block, gpb = start_bipartite_etype_sample_client(
        0, tmpdir, num_server > 1, nodes={'game': empty_nids, 'user': [1]})
    print("Done sampling")
    for proc in servers:
        proc.join()
    assert block is not None
    assert block.number_of_edges() == 0
    assert len(block.etypes) == len(g.etypes)
def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):
    """sample on bipartite via sample_etype_neighbors() which yields non-empty sample results"""
    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
    g = create_random_bipartite()
    num_parts = num_server
    num_hops = 1
    # reshuffle=True renumbers nodes/edges; orig_nids maps shuffled -> original.
    orig_nids, _ = partition_graph(g, 'test_sampling', num_parts, tmpdir,
                                   num_hops=num_hops, part_method='metis', reshuffle=True, return_mapping=True)
    pserver_list = []
    ctx = mp.get_context('spawn')
    # Launch one server process per partition; sleep to let each come up
    # before starting the next.
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, 'test_sampling'))
        p.start()
        time.sleep(1)
        pserver_list.append(p)
    fanout = 3
    # Seed with 'game' nodes that have at least one in-edge, so the
    # sampled frontier is guaranteed non-empty.
    deg = get_degrees(g, orig_nids['game'], 'game')
    nids = F.nonzero_1d(deg > 0)
    block, gpb = start_bipartite_etype_sample_client(0, tmpdir, num_server > 1, fanout,
                                                     nodes={'game': nids, 'user': [0]})
    print("Done sampling")
    for p in pserver_list:
        p.join()
    # Rebuild per-type shuffled-ID -> original-ID lookup tables by scanning
    # every partition's 'orig_id' arrays.
    orig_nid_map = {ntype: F.zeros(
        (g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
    orig_eid_map = {etype: F.zeros(
        (g.number_of_edges(etype),), dtype=F.int64) for etype in g.etypes}
    for i in range(num_server):
        part, _, _, _, _, _, _ = load_partition(
            tmpdir / 'test_sampling.json', i)
        ntype_ids, type_nids = gpb.map_to_per_ntype(part.ndata[dgl.NID])
        for ntype_id, ntype in enumerate(g.ntypes):
            idx = ntype_ids == ntype_id
            F.scatter_row_inplace(orig_nid_map[ntype], F.boolean_mask(type_nids, idx),
                                  F.boolean_mask(part.ndata['orig_id'], idx))
        etype_ids, type_eids = gpb.map_to_per_etype(part.edata[dgl.EID])
        for etype_id, etype in enumerate(g.etypes):
            idx = etype_ids == etype_id
            F.scatter_row_inplace(orig_eid_map[etype], F.boolean_mask(type_eids, idx),
                                  F.boolean_mask(part.edata['orig_id'], idx))
    # Verify every sampled edge: mapping its endpoints and edge ID back to
    # the original graph must agree with g.find_edges on the original ID.
    for src_type, etype, dst_type in block.canonical_etypes:
        src, dst = block.edges(etype=etype)
        # These are global Ids after shuffling.
        shuffled_src = F.gather_row(
            block.srcnodes[src_type].data[dgl.NID], src)
        shuffled_dst = F.gather_row(
            block.dstnodes[dst_type].data[dgl.NID], dst)
        shuffled_eid = block.edges[etype].data[dgl.EID]
        orig_src = F.asnumpy(F.gather_row(
            orig_nid_map[src_type], shuffled_src))
        orig_dst = F.asnumpy(F.gather_row(
            orig_nid_map[dst_type], shuffled_dst))
        orig_eid = F.asnumpy(F.gather_row(orig_eid_map[etype], shuffled_eid))
        # Check the node Ids and edge Ids.
        orig_src1, orig_dst1 = g.find_edges(orig_eid, etype=etype)
        assert np.all(F.asnumpy(orig_src1) == orig_src)
        assert np.all(F.asnumpy(orig_dst1) == orig_dst)
# Wait non shared memory graph store # Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now') @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
...@@ -569,6 +831,10 @@ def test_rpc_sampling_shuffle(num_server): ...@@ -569,6 +831,10 @@ def test_rpc_sampling_shuffle(num_server):
check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_etype_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), num_server) check_rpc_hetero_etype_sampling_empty_shuffle(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_empty(Path(tmpdirname), num_server)
check_rpc_bipartite_sampling_shuffle(Path(tmpdirname), num_server)
check_rpc_bipartite_etype_sampling_empty(Path(tmpdirname), num_server)
check_rpc_bipartite_etype_sampling_shuffle(Path(tmpdirname), num_server)
def check_standalone_sampling(tmpdir, reshuffle): def check_standalone_sampling(tmpdir, reshuffle):
g = CitationGraphDataset("cora")[0] g = CitationGraphDataset("cora")[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment