"git@developer.sourcefind.cn:OpenDAS/pytorch-encoding.git" did not exist on "308447157af829f9d842bd0a2a5c035e0b0ecd29"
Unverified Commit 4dd16f5d authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] enable DistGraph.find_edge() works with str or tuple of str (#4319)

parent 44b68641
...@@ -1132,7 +1132,8 @@ class DistGraph: ...@@ -1132,7 +1132,8 @@ class DistGraph:
gpb = self.get_partition_book() gpb = self.get_partition_book()
if len(gpb.etypes) > 1: if len(gpb.etypes) > 1:
# if etype is a canonical edge type (str, str, str), extract the edge type # if etype is a canonical edge type (str, str, str), extract the edge type
if len(etype) == 3: if isinstance(etype, tuple):
assert len(etype) == 3, 'Invalid canonical etype: {}'.format(etype)
etype = etype[1] etype = etype[1]
edges = gpb.map_to_homo_eid(edges, etype) edges = gpb.map_to_homo_eid(edges, etype)
src, dst = dist_find_edges(self, edges) src, dst = dist_find_edges(self, edges)
......
...@@ -160,9 +160,9 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server): ...@@ -160,9 +160,9 @@ def check_rpc_find_edges_shuffle(tmpdir, num_server):
def create_random_hetero(dense=False, empty=False): def create_random_hetero(dense=False, empty=False):
num_nodes = {'n1': 210, 'n2': 200, 'n3': 220} if dense else \ num_nodes = {'n1': 210, 'n2': 200, 'n3': 220} if dense else \
{'n1': 1010, 'n2': 1000, 'n3': 1020} {'n1': 1010, 'n2': 1000, 'n3': 1020}
etypes = [('n1', 'r1', 'n2'), etypes = [('n1', 'r12', 'n2'),
('n1', 'r2', 'n3'), ('n1', 'r13', 'n3'),
('n2', 'r3', 'n3')] ('n2', 'r23', 'n3')]
edges = {} edges = {}
random.seed(42) random.seed(42)
for etype in etypes: for etype in etypes:
...@@ -195,9 +195,18 @@ def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server): ...@@ -195,9 +195,18 @@ def check_rpc_hetero_find_edges_shuffle(tmpdir, num_server):
time.sleep(1) time.sleep(1)
pserver_list.append(p) pserver_list.append(p)
eids = F.tensor(np.random.randint(g.number_of_edges('r1'), size=100)) eids = F.tensor(np.random.randint(g.num_edges('r12'), size=100))
u, v = g.find_edges(orig_eid['r1'][eids], etype='r1') expect_except = False
du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids, etype='r1') try:
_, _ = g.find_edges(orig_eid['r12'][eids], etype=('n1', 'r12'))
except:
expect_except = True
assert expect_except
u, v = g.find_edges(orig_eid['r12'][eids], etype='r12')
u1, v1 = g.find_edges(orig_eid['r12'][eids], etype=('n1', 'r12', 'n2'))
assert F.array_equal(u, u1)
assert F.array_equal(v, v1)
du, dv = start_find_edges_client(0, tmpdir, num_server > 1, eids, etype='r12')
du = orig_nid['n1'][du] du = orig_nid['n1'][du]
dv = orig_nid['n2'][dv] dv = orig_nid['n2'][dv]
assert F.array_equal(u, du) assert F.array_equal(u, du)
...@@ -488,9 +497,9 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, etype_sorted=Fal ...@@ -488,9 +497,9 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, etype_sorted=Fal
for p in pserver_list: for p in pserver_list:
p.join() p.join()
src, dst = block.edges(etype=('n1', 'r2', 'n3')) src, dst = block.edges(etype=('n1', 'r13', 'n3'))
assert len(src) == 18 assert len(src) == 18
src, dst = block.edges(etype=('n2', 'r3', 'n3')) src, dst = block.edges(etype=('n2', 'r23', 'n3'))
assert len(src) == 18 assert len(src) == 18
orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes} orig_nid_map = {ntype: F.zeros((g.number_of_nodes(ntype),), dtype=F.int64) for ntype in g.ntypes}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment