Unverified commit d76af4d4, authored by Da Zheng, committed by GitHub

[Distributed] Return the ID mapping in graph partitioning. (#2857)



* return mapping.

* support heterogeneous graph.

* more test.

* fix lint.

* fix for diff backends.

* fix.

* fix.
Co-authored-by: Zheng <dzzhen@3c22fba32af5.ant.amazon.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent 6b022d2f
@@ -211,7 +211,7 @@ def load_partition_book(part_config, part_id, graph=None):
                               part_metadata['graph_name'], ntypes, etypes

 def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method="metis",
-                    reshuffle=True, balance_ntypes=None, balance_edges=False):
+                    reshuffle=True, balance_ntypes=None, balance_edges=False, return_mapping=False):
     ''' Partition a graph for distributed training and store the partitions on files.

     The partitioning occurs in three steps: 1) run a partition algorithm (e.g., Metis) to
@@ -382,6 +382,22 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
     balance_edges : bool
         Indicate whether to balance the edges in each partition. This argument is used by
         the Metis algorithm.
+    return_mapping : bool
+        If `reshuffle=True`, this indicates whether to return the mapping between shuffled
+        node/edge IDs and the original node/edge IDs.
+
+    Returns
+    -------
+    Tensor or dict of tensors, optional
+        If `return_mapping=True`, return a 1D tensor that maps each shuffled node ID to its
+        original node ID for a homogeneous graph; for a heterogeneous graph, return a dict
+        mapping each node type to such a tensor for the nodes of that type.
+    Tensor or dict of tensors, optional
+        If `return_mapping=True`, return a 1D tensor that maps each shuffled edge ID to its
+        original edge ID for a homogeneous graph; for a heterogeneous graph, return a dict
+        mapping each edge type to such a tensor for the edges of that type.

     Examples
     --------
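For orientation, here is a minimal sketch of the new API surface described by the docstring above. It is not part of the patch: the toy graph, output path, and feature name are illustrative, and it assumes a DGL build at this commit with METIS support.

```python
import dgl
import torch as th
from dgl.distributed import partition_graph

# A toy homogeneous graph with node features stored in original-ID order.
g = dgl.rand_graph(1000, 10000)
g.ndata['feat'] = th.randn(g.number_of_nodes(), 16)

# With return_mapping=True, partition_graph now also returns the mappings
# from shuffled node/edge IDs back to the IDs of the input graph.
orig_nids, orig_eids = partition_graph(
    g, 'toy', num_parts=2, out_path='/tmp/toy_parts',
    num_hops=1, part_method='metis', reshuffle=True, return_mapping=True)

# orig_nids[i] is the original ID of the node whose shuffled ID is i, so
# data kept in original order can be re-indexed into shuffled order:
feat_shuffled = g.ndata['feat'][orig_nids]
```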
@@ -440,21 +456,41 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
             parts[0] = sim_g.clone()
             parts[0].ndata[NID] = parts[0].ndata['orig_id'] = F.arange(0, sim_g.number_of_nodes())
             parts[0].edata[EID] = parts[0].edata['orig_id'] = F.arange(0, sim_g.number_of_edges())
+            orig_nids = parts[0].ndata['orig_id']
+            orig_eids = parts[0].edata['orig_id']
         else:
             parts[0] = sim_g.clone()
-            parts[0].ndata[NID] = F.arange(0, sim_g.number_of_nodes())
-            parts[0].edata[EID] = F.arange(0, sim_g.number_of_edges())
+            orig_nids = parts[0].ndata[NID] = F.arange(0, sim_g.number_of_nodes())
+            orig_eids = parts[0].edata[EID] = F.arange(0, sim_g.number_of_edges())
         parts[0].ndata['inner_node'] = F.ones((sim_g.number_of_nodes(),), F.int8, F.cpu())
         parts[0].edata['inner_edge'] = F.ones((sim_g.number_of_edges(),), F.int8, F.cpu())
-    elif part_method == 'metis':
+    elif part_method in ('metis', 'random'):
         sim_g, balance_ntypes = get_homogeneous(g, balance_ntypes)
-        node_parts = metis_partition_assignment(sim_g, num_parts, balance_ntypes=balance_ntypes,
-                                                balance_edges=balance_edges)
-        parts = partition_graph_with_halo(sim_g, node_parts, num_hops, reshuffle=reshuffle)
-    elif part_method == 'random':
-        sim_g, _ = get_homogeneous(g, balance_ntypes)
-        node_parts = random_choice(num_parts, sim_g.number_of_nodes())
-        parts = partition_graph_with_halo(sim_g, node_parts, num_hops, reshuffle=reshuffle)
+        if part_method == 'metis':
+            node_parts = metis_partition_assignment(sim_g, num_parts, balance_ntypes=balance_ntypes,
+                                                    balance_edges=balance_edges)
+        else:
+            node_parts = random_choice(num_parts, sim_g.number_of_nodes())
+        parts, orig_nids, orig_eids = partition_graph_with_halo(sim_g, node_parts, num_hops,
+                                                                reshuffle=reshuffle)
+        is_hetero = len(g.etypes) > 1 or len(g.ntypes) > 1
+        if reshuffle and return_mapping and is_hetero:
+            # Get the type IDs.
+            orig_ntype = F.gather_row(sim_g.ndata[NTYPE], orig_nids)
+            orig_etype = F.gather_row(sim_g.edata[ETYPE], orig_eids)
+            # Map the shuffled global IDs to the original per-type IDs.
+            orig_nids = F.gather_row(sim_g.ndata[NID], orig_nids)
+            orig_eids = F.gather_row(sim_g.edata[EID], orig_eids)
+            orig_nids = {ntype: F.boolean_mask(orig_nids, orig_ntype == g.get_ntype_id(ntype)) \
+                         for ntype in g.ntypes}
+            orig_eids = {etype: F.boolean_mask(orig_eids, orig_etype == g.get_etype_id(etype)) \
+                         for etype in g.etypes}
+        elif not reshuffle and not is_hetero and return_mapping:
+            orig_nids = F.arange(0, sim_g.number_of_nodes())
+            orig_eids = F.arange(0, sim_g.number_of_edges())
+        elif not reshuffle and return_mapping:
+            orig_nids = {ntype: F.arange(0, g.number_of_nodes(ntype)) for ntype in g.ntypes}
+            orig_eids = {etype: F.arange(0, g.number_of_edges(etype)) for etype in g.etypes}
     else:
         raise Exception('Unknown partitioning method: ' + part_method)
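The heterogeneous branch above relies on `sim_g` being the homogenized graph, whose node/edge data carry each element's type ID (`NTYPE`/`ETYPE`) and per-type ID (`NID`/`EID`). Here is a self-contained sketch of the masking step with made-up tensors, using plain PyTorch to stand in for the `F` backend wrappers:

```python
import torch as th

ntypes = ['user', 'item']                  # assumed type names
orig_ntype = th.tensor([0, 1, 0, 1, 1])    # node type ID, indexed by shuffled ID
orig_nids = th.tensor([2, 0, 1, 3, 1])     # original per-type ID, indexed by shuffled ID

# Same idea as the F.boolean_mask dict comprehension above: one mapping per type.
per_type = {ntype: orig_nids[orig_ntype == ntype_id]
            for ntype_id, ntype in enumerate(ntypes)}
assert th.equal(per_type['user'], th.tensor([2, 1]))
assert th.equal(per_type['item'], th.tensor([0, 3, 1]))
```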
@@ -709,3 +745,6 @@ def partition_graph(g, graph_name, num_parts, out_path, num_hops=1, part_method=
         num_cuts = 0
     print('There are {} edges in the graph and {} edge cuts for {} partitions.'.format(
         g.number_of_edges(), num_cuts, num_parts))
+
+    if return_mapping:
+        return orig_nids, orig_eids
@@ -146,6 +146,12 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
     --------
     a dict of DGLGraphs
         The key is the partition ID and the value is the DGLGraph of the partition.
+    Tensor
+        A 1D tensor that stores the mapping between the reshuffled node IDs and
+        the original node IDs if `reshuffle=True`. Otherwise, return None.
+    Tensor
+        A 1D tensor that stores the mapping between the reshuffled edge IDs and
+        the original edge IDs if `reshuffle=True`. Otherwise, return None.
     '''
     assert len(node_part) == g.number_of_nodes()
     if reshuffle:
@@ -194,7 +200,10 @@ def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
         subg.edata['inner_edge'] = inner_edge
         subg_dict[i] = subg
     print('Construct subgraphs: {:.3f} seconds'.format(time.time() - start))
-    return subg_dict
+    if reshuffle:
+        return subg_dict, orig_nids, orig_eids
+    else:
+        return subg_dict, None, None

 def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False):
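Because `partition_graph_with_halo` now always returns a 3-tuple, every caller has to unpack three values. A hedged sketch of the new calling convention (the toy graph and partition assignment are illustrative):

```python
import dgl
import numpy as np

g = dgl.rand_graph(100, 500)
node_part = np.random.choice(4, g.number_of_nodes())   # random 4-way assignment
subgs, orig_nids, orig_eids = dgl.transform.partition_graph_with_halo(
    g, node_part, 1, reshuffle=True)
# With reshuffle=False, the two mapping slots come back as None instead.
```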
@@ -342,6 +351,6 @@ def metis_partition(g, k, extra_cached_hops=0, reshuffle=False,
         return None
     # Then we split the original graph into parts based on the METIS partitioning results.
-    return partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle)
+    return partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle)[0]

 _init_api("dgl.partition")
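The trailing `[0]` keeps `metis_partition`'s public contract unchanged: callers still receive only the dict of partition subgraphs, and the new mapping slots stay internal to this module. For example (assuming a METIS-enabled build and the `dgl.transform` re-export used elsewhere in these tests):

```python
parts = dgl.transform.metis_partition(g, 4)  # still {part_id: DGLGraph}, as before
```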
@@ -504,7 +504,7 @@ def get_nodeflow(g, node_ids, num_layers):
 def test_partition_with_halo():
     g = create_large_graph(1000)
     node_part = np.random.choice(4, g.number_of_nodes())
-    subgs = dgl.transform.partition_graph_with_halo(g, node_part, 2, reshuffle=True)
+    subgs, _, _ = dgl.transform.partition_graph_with_halo(g, node_part, 2, reshuffle=True)
     for part_id, subg in subgs.items():
         node_ids = np.nonzero(node_part == part_id)[0]
         lnode_ids = np.nonzero(F.asnumpy(subg.ndata['inner_node']))[0]
@@ -45,7 +45,7 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients):
     g.start()

-def start_dist_dataloader(rank, tmpdir, num_server, drop_last):
+def start_dist_dataloader(rank, tmpdir, num_server, drop_last, orig_nid, orig_eid):
     import dgl
     import torch as th
     dgl.distributed.initialize("mp_ip_config.txt")
@@ -58,14 +58,8 @@ def start_dist_dataloader(rank, tmpdir, num_server, drop_last):
     train_nid = th.arange(num_nodes_to_sample)
     dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')

-    orig_nid = F.arange(0, dist_graph.number_of_nodes())
-    orig_eid = F.arange(0, dist_graph.number_of_edges())
     for i in range(num_server):
         part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
-        if 'orig_id' in part.ndata:
-            orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
-        if 'orig_id' in part.edata:
-            orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

     # Create sampler
     sampler = NeighborSampler(dist_graph, [5, 10],
@@ -117,12 +111,13 @@ def test_standalone(tmpdir):
     num_parts = 1
     num_hops = 1

-    partition_graph(g, 'test_sampling', num_parts, tmpdir,
-                    num_hops=num_hops, part_method='metis', reshuffle=True)
+    orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+                                         num_hops=num_hops, part_method='metis', reshuffle=True,
+                                         return_mapping=True)

     os.environ['DGL_DIST_MODE'] = 'standalone'
     try:
-        start_dist_dataloader(0, tmpdir, 1, True)
+        start_dist_dataloader(0, tmpdir, 1, True, orig_nid, orig_eid)
     except Exception as e:
         print(e)
     dgl.distributed.exit_client()  # this is needed since there are two tests here in one process
@@ -145,8 +140,9 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
     num_parts = num_server
     num_hops = 1

-    partition_graph(g, 'test_sampling', num_parts, tmpdir,
-                    num_hops=num_hops, part_method='metis', reshuffle=reshuffle)
+    orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+                                         num_hops=num_hops, part_method='metis',
+                                         reshuffle=reshuffle, return_mapping=True)

     pserver_list = []
     ctx = mp.get_context('spawn')
@@ -161,7 +157,7 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
     os.environ['DGL_DIST_MODE'] = 'distributed'
     os.environ['DGL_NUM_SAMPLER'] = str(num_workers)
     ptrainer = ctx.Process(target=start_dist_dataloader, args=(
-        0, tmpdir, num_server, drop_last))
+        0, tmpdir, num_server, drop_last, orig_nid, orig_eid))
     ptrainer.start()
     time.sleep(1)
@@ -169,7 +165,7 @@ def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last, reshuffle):
         p.join()
     ptrainer.join()

-def start_node_dataloader(rank, tmpdir, num_server, num_workers):
+def start_node_dataloader(rank, tmpdir, num_server, num_workers, orig_nid, orig_eid):
     import dgl
     import torch as th
     dgl.distributed.initialize("mp_ip_config.txt")
@@ -182,12 +178,8 @@ def start_node_dataloader(rank, tmpdir, num_server, num_workers):
     train_nid = th.arange(num_nodes_to_sample)
     dist_graph = DistGraph("test_mp", gpb=gpb, part_config=tmpdir / 'test_sampling.json')

-    orig_nid = F.zeros((dist_graph.number_of_nodes(),), dtype=F.int64)
-    orig_eid = F.zeros((dist_graph.number_of_edges(),), dtype=F.int64)
     for i in range(num_server):
         part, _, _, _, _, _, _ = load_partition(tmpdir / 'test_sampling.json', i)
-        orig_nid[part.ndata[dgl.NID]] = part.ndata['orig_id']
-        orig_eid[part.edata[dgl.EID]] = part.edata['orig_id']

     # Create sampler
     sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10])
@@ -239,8 +231,9 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
     num_parts = num_server
     num_hops = 1

-    partition_graph(g, 'test_sampling', num_parts, tmpdir,
-                    num_hops=num_hops, part_method='metis', reshuffle=True)
+    orig_nid, orig_eid = partition_graph(g, 'test_sampling', num_parts, tmpdir,
+                                         num_hops=num_hops, part_method='metis',
+                                         reshuffle=True, return_mapping=True)

     pserver_list = []
     ctx = mp.get_context('spawn')
@@ -257,7 +250,7 @@ def test_dataloader(tmpdir, num_server, num_workers, dataloader_type):
     ptrainer_list = []
     if dataloader_type == 'node':
         p = ctx.Process(target=start_node_dataloader, args=(
-            0, tmpdir, num_server, num_workers))
+            0, tmpdir, num_server, num_workers, orig_nid, orig_eid))
         p.start()
         time.sleep(1)
         ptrainer_list.append(p)
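With the mappings computed once by `partition_graph` and passed through the `Process` args, a worker can translate the shuffled IDs carried by sampled blocks back to original IDs. A hedged sketch of such a check; `block`, `g`, and the surrounding setup are illustrative, not the test's exact code:

```python
# Inside the worker, for one sampled mini-batch `block` and original graph `g`:
src, dst = block.edges()
orig_src = orig_nid[block.srcdata[dgl.NID][src]]
orig_dst = orig_nid[block.dstdata[dgl.NID][dst]]
# Every sampled edge must exist in the original graph.
assert th.all(g.has_edges_between(orig_src, orig_dst))
```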
@@ -141,11 +141,45 @@ def check_hetero_partition(hg, part_method):
     num_parts = 4
     num_hops = 1

-    partition_graph(hg, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
-                    part_method=part_method, reshuffle=True)
+    orig_nids, orig_eids = partition_graph(hg, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
+                                           part_method=part_method, reshuffle=True, return_mapping=True)
+    assert len(orig_nids) == len(hg.ntypes)
+    assert len(orig_eids) == len(hg.etypes)
+    for ntype in hg.ntypes:
+        assert len(orig_nids[ntype]) == hg.number_of_nodes(ntype)
+    for etype in hg.etypes:
+        assert len(orig_eids[etype]) == hg.number_of_edges(etype)
     parts = []
     for i in range(num_parts):
         part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition('/tmp/partition/test.json', i)
+        # Verify the mapping between the reshuffled IDs and the original IDs.
+        # These are partition-local IDs.
+        part_src_ids, part_dst_ids = part_g.edges()
+        # These are reshuffled global homogeneous IDs.
+        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
+        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
+        part_eids = part_g.edata[dgl.EID]
+        # These are reshuffled per-type IDs.
+        src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids)
+        dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids)
+        etype_ids, part_eids = gpb.map_to_per_etype(part_eids)
+        # These are the original per-type IDs.
+        for etype_id, etype in enumerate(hg.etypes):
+            part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id)
+            src_ntype_ids1 = F.boolean_mask(src_ntype_ids, etype_ids == etype_id)
+            part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id)
+            dst_ntype_ids1 = F.boolean_mask(dst_ntype_ids, etype_ids == etype_id)
+            part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id)
+            assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0]))
+            assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0]))
+            src_ntype = hg.ntypes[F.as_scalar(src_ntype_ids1[0])]
+            dst_ntype = hg.ntypes[F.as_scalar(dst_ntype_ids1[0])]
+            orig_src_ids1 = F.gather_row(orig_nids[src_ntype], part_src_ids1)
+            orig_dst_ids1 = F.gather_row(orig_nids[dst_ntype], part_dst_ids1)
+            orig_eids1 = F.gather_row(orig_eids[etype], part_eids1)
+            orig_eids2 = hg.edge_ids(orig_src_ids1, orig_dst_ids1, etype=etype)
+            assert len(orig_eids1) == len(orig_eids2)
+            assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
         parts.append(part_g)
         verify_graph_feats(hg, part_g, node_feats)

     verify_hetero_graph(hg, parts)
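The heterogeneous check above hinges on a two-step translation: `gpb.map_to_per_ntype` / `map_to_per_etype` split a shuffled homogeneous ID into a (type ID, shuffled per-type ID) pair, and the dicts returned by `partition_graph` then take the per-type ID back to the original graph. Schematically, reusing the test's variables ('user' is an assumed type name):

```python
# Step 1: shuffled homogeneous ID -> (node type ID, shuffled per-type ID).
ntype_ids, per_type_ids = gpb.map_to_per_ntype(part_src_ids)
# Step 2: shuffled per-type ID -> original per-type ID, via the returned dict.
user_id = hg.get_ntype_id('user')
orig_user_src = orig_nids['user'][per_type_ids[ntype_ids == user_id]]
```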
@@ -159,8 +193,8 @@ def check_partition(g, part_method, reshuffle):
     num_parts = 4
     num_hops = 2

-    partition_graph(g, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
-                    part_method=part_method, reshuffle=reshuffle)
+    orig_nids, orig_eids = partition_graph(g, 'test', num_parts, '/tmp/partition', num_hops=num_hops,
+                                           part_method=part_method, reshuffle=reshuffle, return_mapping=True)
     part_sizes = []
     for i in range(num_parts):
         part_g, node_feats, edge_feats, gpb, _, ntypes, etypes = load_partition('/tmp/partition/test.json', i)
@@ -196,6 +230,18 @@ def check_partition(g, part_method, reshuffle):
             assert F.dtype(local_edges1) in (F.int32, F.int64)
             assert np.all(np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(local_edges1)))

+        # Verify the mapping between the reshuffled IDs and the original IDs.
+        part_src_ids, part_dst_ids = part_g.edges()
+        part_src_ids = F.gather_row(part_g.ndata[dgl.NID], part_src_ids)
+        part_dst_ids = F.gather_row(part_g.ndata[dgl.NID], part_dst_ids)
+        part_eids = part_g.edata[dgl.EID]
+        orig_src_ids = F.gather_row(orig_nids, part_src_ids)
+        orig_dst_ids = F.gather_row(orig_nids, part_dst_ids)
+        orig_eids1 = F.gather_row(orig_eids, part_eids)
+        orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids)
+        assert F.shape(orig_eids1)[0] == F.shape(orig_eids2)[0]
+        assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2))
+
         if reshuffle:
             part_g.ndata['feats'] = F.gather_row(g.ndata['feats'], part_g.ndata['orig_id'])
             part_g.edata['feats'] = F.gather_row(g.edata['feats'], part_g.edata['orig_id'])
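The homogeneous assertion reduces to a simple identity: permuting the endpoint lists by the shuffled-to-original edge mapping and asking the original graph for the edge IDs must reproduce that same mapping. A self-contained toy demonstration (illustrative only):

```python
import dgl
import torch as th

g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
orig_eids = th.tensor([2, 0, 1])           # a pretend shuffled->original edge map
src, dst = g.edges()
assert th.equal(g.edge_ids(src[orig_eids], dst[orig_eids]), orig_eids)
```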