import os

import numpy as np
import torch

import dgl

from partition_utils import get_partition_list


class ClusterIter(object):
    '''The partition sampler given a DGLGraph and a partition number.
    METIS is used as the graph partitioning backend.
    '''

    def __init__(self, dn, g, psize, batch_size, seed_nid):
        """Initialize the sampler.

        Parameters
        ----------
        dn : str
            The dataset name.
        g : DGLGraph
            The full graph of the dataset.
        psize : int
            The number of partitions.
        batch_size : int
            The number of partitions in one batch.
        seed_nid : np.ndarray
            The training node ids, used to extract the training graph.
        """
        self.psize = psize
        self.batch_size = batch_size
        # Cache the partition list for known dataset & partition-number pairs.
        if dn:
            cache_path = os.path.join('./datasets/', dn + '_{}.npy'.format(psize))
            if os.path.exists(cache_path):
                self.par_li = np.load(cache_path, allow_pickle=True)
            else:
                os.makedirs('./datasets/', exist_ok=True)
                self.par_li = get_partition_list(g, psize)
                np.save(cache_path, self.par_li)
        else:
            self.par_li = get_partition_list(g, psize)
        par_list = []
        for p in self.par_li:
            par = torch.Tensor(p)
            par_list.append(par)
        self.par_list = par_list

    # Use one-sided (in-degree) normalization.
    def get_norm(self, g):
        norm = 1. / g.in_degrees().float().unsqueeze(1)
        norm[torch.isinf(norm)] = 0
        norm = norm.to(g.ndata['feat'].device)
        return norm

    def __len__(self):
        return self.psize

    def __getitem__(self, idx):
        return self.par_li[idx]


def subgraph_collate_fn(g, batch):
    # Merge the sampled partitions into one node-induced subgraph and copy
    # the relevant node data over from the parent graph.
    nids = np.concatenate(batch).reshape(-1).astype(np.int64)
    g1 = g.subgraph(nids)
    nid = g1.ndata[dgl.NID]
    g1.ndata['feat'] = g.ndata['feat'][nid]
    g1.ndata['labels'] = g.ndata['labels'][nid]
    g1.ndata['train_mask'] = g.ndata['train_mask'][nid]
    g1.create_formats_()
    return g1
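
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the sampler itself): how the
# partition sampler can be wired into a torch DataLoader for cluster-GCN
# style mini-batch training.  Assumes `g` is a DGLGraph whose ndata already
# holds 'feat', 'labels' and 'train_mask', and that `train_nid` is an
# np.ndarray of training node ids; the dataset name 'demo', 100 partitions,
# and 4 partitions per batch are placeholder values.
#
#   from functools import partial
#   from torch.utils.data import DataLoader
#
#   cluster_iter = ClusterIter('demo', g, psize=100, batch_size=4,
#                              seed_nid=train_nid)
#   loader = DataLoader(cluster_iter, batch_size=4, shuffle=True,
#                       collate_fn=partial(subgraph_collate_fn, g))
#   for subg in loader:
#       feats = subg.ndata['feat']      # features of the merged partitions
#       labels = subg.ndata['labels']   # labels copied from the parent graph
#       # forward / backward pass on the subgraph goes here
# ---------------------------------------------------------------------------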