partition_utils.py 656 Bytes
Newer Older
1
2
3
4
5
6
from time import time

import numpy as np

from utils import arg_list

7
8
9
10
from dgl.transform import metis_partition
from dgl import backend as F
import dgl

11
def get_partition_list(g, psize):
    """Partition ``g`` with METIS and return the node IDs of every part.

    Args:
        g: The DGL graph to partition.
        psize: Number of partitions requested from ``metis_partition``.

    Returns:
        A list with one NumPy array per partition, in the iteration order of
        the dict returned by ``metis_partition``. Each array is the
        ``dgl.NID`` node data of that partition's subgraph, i.e. the IDs the
        partition's nodes have in ``g``.
    """
    parts = metis_partition(g, psize)
    # Only the original-graph node IDs of each partition subgraph are kept;
    # they are converted from backend tensors to NumPy so callers can
    # concatenate/index them without touching the DL backend.
    return [F.asnumpy(part.ndata[dgl.NID]) for part in parts.values()]
19
20
21
22
23
24
25

def get_subgraph(g, par_arr, i, psize, batch_size):
    """Build the node-induced subgraph for the ``i``-th batch of partitions.

    Batch ``i`` covers partition indices ``i * batch_size`` up to (but
    excluding) ``(i + 1) * batch_size``, truncated at ``psize``, so the last
    batch may contain fewer partitions.

    Args:
        g: Graph object providing ``subgraph(node_ids)``.
        par_arr: Indexable collection of per-partition node-ID arrays
            (presumably the output of ``get_partition_list`` -- verify at
            the call site).
        i: Zero-based batch index.
        psize: Total number of partitions; indices ``>= psize`` are skipped.
        batch_size: Number of partitions merged into one batch.

    Returns:
        The result of ``g.subgraph`` on the concatenated node IDs of the
        selected partitions, flattened and cast to ``int64``.

    Raises:
        ValueError: If the batch selects no partitions (negative ``i``,
            non-positive ``batch_size``, or ``i * batch_size >= psize``).
    """
    start = i * batch_size
    stop = min(start + batch_size, psize)
    # Reject empty or negative ranges explicitly: a negative start would
    # silently wrap around to the last partitions via negative indexing,
    # and an empty selection makes np.concatenate fail with an opaque error.
    if start < 0 or start >= stop:
        raise ValueError(
            f"batch {i} (batch_size={batch_size}) selects no partitions "
            f"out of psize={psize}")
    node_ids = np.concatenate([par_arr[s] for s in range(start, stop)])
    return g.subgraph(node_ids.reshape(-1).astype(np.int64))