Commit d70ca6eb authored by lt610, committed by GitHub

[Example] graphsaint (#2792)



* graphsaint

* graphsaint

* graphsaint

* graphsaint

* fixed the model

* fixed some bugs

* fixed the computing of normalization and updated the results

* fixed some bugs and updated the results

* Update utils.py

* Update train_sampling.py

* Update train_sampling.py
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Quan (Andy) Gan <coin2028@hotmail.com>
parent 103444c5
# GraphSAINT
This DGL example implements the paper GraphSAINT: Graph Sampling Based Inductive Learning Method.
Paper link: https://arxiv.org/abs/1907.04931
Authors' code: https://github.com/GraphSAINT/GraphSAINT
Contributor: Liu Tang ([@lt610](https://github.com/lt610))
## Dependencies
- Python 3.7.0
- PyTorch 1.6.0
- NumPy 1.19.2
- Scikit-learn 0.23.2
- DGL 0.5.3
## Dataset
All datasets used are provided by the authors' [code](https://github.com/GraphSAINT/GraphSAINT). They are available on [Google Drive](https://drive.google.com/drive/folders/1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz) (alternatively, [Baidu Wangpan (code: f1ao)](https://pan.baidu.com/s/1SOb0SiSAXavwAcNqkttwcg#list/path=%2F)). After downloading, rename the `graphsaintdata` folder to `data`; the expected layout is sketched below the table. Dataset summary ("m" stands for multi-label classification, "s" for single-label):
| Dataset | Nodes | Edges | Degree | Feature | Classes | Train/Val/Test |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| PPI | 14,755 | 225,270 | 15 | 50 | 121(m) | 0.66/0.12/0.22 |
| Flickr | 89,250 | 899,756 | 10 | 500 | 7(s) | 0.50/0.25/0.25 |
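The training script reads the following files for each dataset (see `load_data` in `utils.py`). A sketch of the expected `data` directory after renaming:
```
data/
  ppi/
    adj_full.npz    # adjacency of the full graph (scipy sparse matrix)
    adj_train.npz   # adjacency restricted to the training nodes
    role.json       # train/val/test node id lists ('tr', 'va', 'te')
    feats.npy       # node feature matrix
    class_map.json  # node id -> label (a list of 0/1 for multi-label datasets)
  flickr/
    ...             # same files as above
```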
## Minibatch training
Run with the following commands:
```bash
python train_sampling.py --gpu 0 --dataset ppi --sampler node --node-budget 6000 --num-repeat 50 --n-epochs 1000 --n-hidden 512 --arch 1-0-1-0
python train_sampling.py --gpu 0 --dataset ppi --sampler edge --edge-budget 4000 --num-repeat 50 --n-epochs 1000 --n-hidden 512 --arch 1-0-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset ppi --sampler rw --num-roots 3000 --length 2 --num-repeat 50 --n-epochs 1000 --n-hidden 512 --arch 1-0-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset flickr --sampler node --node-budget 8000 --num-repeat 25 --n-epochs 30 --n-hidden 256 --arch 1-1-0 --dropout 0.2
python train_sampling.py --gpu 0 --dataset flickr --sampler edge --edge-budget 6000 --num-repeat 25 --n-epochs 15 --n-hidden 256 --arch 1-1-0 --dropout 0.2
python train_sampling.py --gpu 0 --dataset flickr --sampler rw --num-roots 6000 --length 2 --num-repeat 25 --n-epochs 15 --n-hidden 256 --arch 1-1-0 --dropout 0.2
```
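To evaluate the best validation checkpoint on the test set, add `--use-val`; `--log-dir` sets the directory (under `log/{dataset}/`) where logs and the best model are written. Both flags are defined in `train_sampling.py`; the directory name below is just an example:
```bash
python train_sampling.py --gpu 0 --dataset ppi --sampler node --node-budget 6000 --num-repeat 50 --n-epochs 1000 --n-hidden 512 --arch 1-0-1-0 --use-val --log-dir ppi_node
```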
## Comparison
* Paper: results from the paper
* Running: results from experiments with the authors' code
* DGL: results from experiments with the DGL example
### F1-micro
#### Random node sampler
| Method | PPI | Flickr |
| --- | --- | --- |
| Paper | 0.960±0.001 | 0.507±0.001 |
| Running | 0.9628 | 0.5077 |
| DGL | 0.9618 | 0.4828 |
#### Random edge sampler
| Method | PPI | Flickr |
| --- | --- | --- |
| Paper | 0.981±0.007 | 0.510±0.002 |
| Running | 0.9810 | 0.5066 |
| DGL | 0.9818 | 0.5054 |
#### Random walk sampler
| Method | PPI | Flickr |
| --- | --- | --- |
| Paper | 0.981±0.004 | 0.511±0.001 |
| Running | 0.9812 | 0.5104 |
| DGL | 0.9818 | 0.5018 |
### Sampling and normalization time
#### Random node sampler
| Stage | PPI | Flickr |
| --- | --- | --- |
| Sampling (Running) | 0.77 | 0.65 |
| Sampling (DGL) | 0.24 | 0.57 |
| Normalization (Running) | 0.69 | 2.84 |
| Normalization (DGL) | 1.04 | 0.41 |
#### Random edge sampler
| Stage | PPI | Flickr |
| --- | --- | --- |
| Sampling (Running) | 0.72 | 0.56 |
| Sampling (DGL) | 0.50 | 0.72 |
| Normalization (Running) | 0.68 | 2.62 |
| Normalization (DGL) | 0.61 | 0.38 |
#### Random walk sampler
| Stage | PPI | Flickr |
| --- | --- | --- |
| Sampling (Running) | 0.83 | 1.22 |
| Sampling (DGL) | 0.28 | 0.63 |
| Normalization (Running) | 0.87 | 2.60 |
| Normalization (DGL) | 0.70 | 0.42 |
import torch.nn as nn
import torch.nn.functional as F
import torch as th
import dgl.function as fn
class GCNLayer(nn.Module):
def __init__(self, in_dim, out_dim, order=1, act=None,
dropout=0, batch_norm=False, aggr="concat"):
super(GCNLayer, self).__init__()
self.lins = nn.ModuleList()
self.bias = nn.ParameterList()
for _ in range(order + 1):
self.lins.append(nn.Linear(in_dim, out_dim, bias=False))
self.bias.append(nn.Parameter(th.zeros(out_dim)))
self.order = order
self.act = act
self.dropout = nn.Dropout(dropout)
self.batch_norm = batch_norm
if batch_norm:
self.offset, self.scale = nn.ParameterList(), nn.ParameterList()
for _ in range(order + 1):
self.offset.append(nn.Parameter(th.zeros(out_dim)))
self.scale.append(nn.Parameter(th.ones(out_dim)))
self.aggr = aggr
self.reset_parameters()
def reset_parameters(self):
for lin in self.lins:
nn.init.xavier_normal_(lin.weight)
def feat_trans(self, features, idx):
h = self.lins[idx](features) + self.bias[idx]
if self.act is not None:
h = self.act(h)
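        # When batch_norm is enabled, normalize each node's activations over the feature
        # dimension with a learnable scale and offset (a layer-norm style normalization).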
if self.batch_norm:
mean = h.mean(dim=1).view(h.shape[0], 1)
var = h.var(dim=1, unbiased=False).view(h.shape[0], 1) + 1e-9
h = (h - mean) * self.scale[idx] * th.rsqrt(var) + self.offset[idx]
return h
def forward(self, graph, features):
g = graph.local_var()
h_in = self.dropout(features)
h_hop = [h_in]
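        # 1/in-degree normalization: 'train_D_norm' is precomputed by the sampler on the
        # training graph, while 'full_D_norm' is used on the full graph at evaluation time.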
D_norm = g.ndata['train_D_norm'] if 'train_D_norm' in g.ndata else g.ndata['full_D_norm']
for _ in range(self.order):
g.ndata['h'] = h_hop[-1]
if 'w' not in g.edata:
g.edata['w'] = th.ones((g.num_edges(), )).to(features.device)
g.update_all(fn.u_mul_e('h', 'w', 'm'),
fn.sum('m', 'h'))
h = g.ndata.pop('h')
h = h * D_norm
h_hop.append(h)
h_part = [self.feat_trans(ft, idx) for idx, ft in enumerate(h_hop)]
if self.aggr == "mean":
h_out = h_part[0]
for i in range(len(h_part) - 1):
h_out = h_out + h_part[i + 1]
elif self.aggr == "concat":
h_out = th.cat(h_part, 1)
else:
raise NotImplementedError
return h_out
class GCNNet(nn.Module):
def __init__(self, in_dim, hid_dim, out_dim, arch="1-1-0",
act=F.relu, dropout=0, batch_norm=False, aggr="concat"):
super(GCNNet, self).__init__()
self.gcn = nn.ModuleList()
orders = list(map(int, arch.split('-')))
self.gcn.append(GCNLayer(in_dim=in_dim, out_dim=hid_dim, order=orders[0],
act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr))
pre_out = ((aggr == "concat") * orders[0] + 1) * hid_dim
for i in range(1, len(orders)-1):
self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[i],
act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr))
pre_out = ((aggr == "concat") * orders[i] + 1) * hid_dim
self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[-1],
act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr))
pre_out = ((aggr == "concat") * orders[-1] + 1) * hid_dim
self.out_layer = GCNLayer(in_dim=pre_out, out_dim=out_dim, order=0,
act=None, dropout=dropout, batch_norm=False, aggr=aggr)
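    # Example: arch="1-0-1-0" with hid_dim=512 and aggr="concat" builds four GCN layers of
    # orders [1, 0, 1, 0] whose output widths are 1024, 512, 1024 and 512; the final order-0
    # out_layer then maps 512 features to out_dim.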
def forward(self, graph):
h = graph.ndata['feat']
for layer in self.gcn:
h = layer(graph, h)
h = F.normalize(h, p=2, dim=1)
h = self.out_layer(graph, h)
return h
import math
import os
import time
import torch as th
import random
import numpy as np
import dgl.function as fn
import dgl
from dgl.sampling import random_walk, pack_traces
# The base class of sampler
# (TODO): online sampling
class SAINTSampler(object):
def __init__(self, dn, g, train_nid, node_budget, num_repeat=50):
"""
:param dn: name of dataset
:param g: full graph
:param train_nid: ids of training nodes
:param node_budget: expected number of sampled nodes
:param num_repeat: number of times of repeating sampling one node
"""
self.g = g
self.train_g: dgl.graph = g.subgraph(train_nid)
self.dn, self.num_repeat = dn, num_repeat
self.node_counter = th.zeros((self.train_g.num_nodes(),))
self.edge_counter = th.zeros((self.train_g.num_edges(),))
self.prob = None
graph_fn, norm_fn = self.__generate_fn__()
if os.path.exists(graph_fn):
self.subgraphs = np.load(graph_fn, allow_pickle=True)
aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True)
else:
os.makedirs('./subgraphs/', exist_ok=True)
self.subgraphs = []
self.N, sampled_nodes = 0, 0
t = time.perf_counter()
while sampled_nodes <= self.train_g.num_nodes() * num_repeat:
subgraph = self.__sample__()
self.subgraphs.append(subgraph)
sampled_nodes += subgraph.shape[0]
self.N += 1
print(f'Sampling time: [{time.perf_counter() - t:.2f}s]')
np.save(graph_fn, self.subgraphs)
t = time.perf_counter()
self.__counter__()
aggr_norm, loss_norm = self.__compute_norm__()
print(f'Normalization time: [{time.perf_counter() - t:.2f}s]')
np.save(norm_fn, (aggr_norm, loss_norm))
self.train_g.ndata['l_n'] = th.Tensor(loss_norm)
self.train_g.edata['w'] = th.Tensor(aggr_norm)
self.__compute_degree_norm()
self.num_batch = math.ceil(self.train_g.num_nodes() / node_budget)
random.shuffle(self.subgraphs)
self.__clear__()
        print("The number of subgraphs is: ", len(self.subgraphs))
        print("The size of each subgraph is about: ", len(self.subgraphs[-1]))
def __clear__(self):
self.prob = None
self.node_counter = None
self.edge_counter = None
self.g = None
def __counter__(self):
for sampled_nodes in self.subgraphs:
sampled_nodes = th.from_numpy(sampled_nodes)
self.node_counter[sampled_nodes] += 1
subg = self.train_g.subgraph(sampled_nodes)
sampled_edges = subg.edata[dgl.EID]
self.edge_counter[sampled_edges] += 1
def __generate_fn__(self):
raise NotImplementedError
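    # GraphSAINT's normalization coefficients, estimated from the pre-sampled subgraphs:
    #  - loss_norm[v] = N / (C_v * |V_train|), where C_v is the number of subgraphs containing
    #    node v and N is the number of subgraphs; it reweights node v's loss so that frequently
    #    sampled nodes do not dominate the objective.
    #  - aggr_norm[(u, v)] = C_v / C_{u, v}, used as the edge weight 'w' during message passing
    #    to de-bias neighbor aggregation (C_{u, v} counts subgraphs containing edge (u, v)).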
def __compute_norm__(self):
self.node_counter[self.node_counter == 0] = 1
self.edge_counter[self.edge_counter == 0] = 1
loss_norm = self.N / self.node_counter / self.train_g.num_nodes()
self.train_g.ndata['n_c'] = self.node_counter
self.train_g.edata['e_c'] = self.edge_counter
self.train_g.apply_edges(fn.v_div_e('n_c', 'e_c', 'a_n'))
aggr_norm = self.train_g.edata.pop('a_n')
self.train_g.ndata.pop('n_c')
self.train_g.edata.pop('e_c')
return aggr_norm.numpy(), loss_norm.numpy()
def __compute_degree_norm(self):
self.train_g.ndata['train_D_norm'] = 1. / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)
self.g.ndata['full_D_norm'] = 1. / self.g.in_degrees().float().clamp(min=1).unsqueeze(1)
def __sample__(self):
raise NotImplementedError
def __len__(self):
return self.num_batch
def __iter__(self):
self.n = 0
return self
def __next__(self):
if self.n < self.num_batch:
result = self.train_g.subgraph(self.subgraphs[self.n])
self.n += 1
return result
else:
random.shuffle(self.subgraphs)
raise StopIteration()
class SAINTNodeSampler(SAINTSampler):
def __init__(self, node_budget, dn, g, train_nid, num_repeat=50):
self.node_budget = node_budget
super(SAINTNodeSampler, self).__init__(dn, g, train_nid, node_budget, num_repeat)
def __generate_fn__(self):
graph_fn = os.path.join('./subgraphs/{}_Node_{}_{}.npy'.format(self.dn, self.node_budget,
self.num_repeat))
norm_fn = os.path.join('./subgraphs/{}_Node_{}_{}_norm.npy'.format(self.dn, self.node_budget,
self.num_repeat))
return graph_fn, norm_fn
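    # Node sampling: each node is drawn with probability proportional to its in-degree
    # (clamped to at least 1), with replacement, and duplicates are removed.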
def __sample__(self):
if self.prob is None:
self.prob = self.train_g.in_degrees().float().clamp(min=1)
sampled_nodes = th.multinomial(self.prob, num_samples=self.node_budget, replacement=True).unique()
return sampled_nodes.numpy()
class SAINTEdgeSampler(SAINTSampler):
def __init__(self, edge_budget, dn, g, train_nid, num_repeat=50):
self.edge_budget = edge_budget
super(SAINTEdgeSampler, self).__init__(dn, g, train_nid, edge_budget * 2, num_repeat)
def __generate_fn__(self):
graph_fn = os.path.join('./subgraphs/{}_Edge_{}_{}.npy'.format(self.dn, self.edge_budget,
self.num_repeat))
norm_fn = os.path.join('./subgraphs/{}_Edge_{}_{}_norm.npy'.format(self.dn, self.edge_budget,
self.num_repeat))
return graph_fn, norm_fn
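    # Edge sampling: each edge (u, v) is drawn with probability proportional to
    # 1/deg(u) + 1/deg(v), as in the paper's edge sampler; the subgraph is induced on the
    # endpoints of the sampled edges.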
def __sample__(self):
if self.prob is None:
src, dst = self.train_g.edges()
src_degrees, dst_degrees = self.train_g.in_degrees(src).float().clamp(min=1),\
self.train_g.in_degrees(dst).float().clamp(min=1)
self.prob = 1. / src_degrees + 1. / dst_degrees
sampled_edges = th.multinomial(self.prob, num_samples=self.edge_budget, replacement=True).unique()
sampled_src, sampled_dst = self.train_g.find_edges(sampled_edges)
sampled_nodes = th.cat([sampled_src, sampled_dst]).unique()
return sampled_nodes.numpy()
class SAINTRandomWalkSampler(SAINTSampler):
def __init__(self, num_roots, length, dn, g, train_nid, num_repeat=50):
self.num_roots, self.length = num_roots, length
super(SAINTRandomWalkSampler, self).__init__(dn, g, train_nid, num_roots * length, num_repeat)
def __generate_fn__(self):
graph_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}.npy'.format(self.dn, self.num_roots,
self.length, self.num_repeat))
norm_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}_norm.npy'.format(self.dn, self.num_roots,
self.length, self.num_repeat))
return graph_fn, norm_fn
def __sample__(self):
sampled_roots = th.randint(0, self.train_g.num_nodes(), (self.num_roots, ))
traces, types = random_walk(self.train_g, nodes=sampled_roots, length=self.length)
sampled_nodes, _, _, _ = pack_traces(traces, types)
sampled_nodes = sampled_nodes.unique()
return sampled_nodes.numpy()
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from sampler import SAINTNodeSampler, SAINTEdgeSampler, SAINTRandomWalkSampler
from modules import GCNNet
from utils import Logger, evaluate, save_log_dir, load_data
def main(args):
multilabel_data = set(['ppi'])
multilabel = args.dataset in multilabel_data
# load and preprocess dataset
data = load_data(args, multilabel)
g = data.g
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
labels = g.ndata['label']
train_nid = data.train_nid
in_feats = g.ndata['feat'].shape[1]
n_classes = data.num_classes
n_nodes = g.num_nodes()
n_edges = g.num_edges()
n_train_samples = train_mask.int().sum().item()
n_val_samples = val_mask.int().sum().item()
n_test_samples = test_mask.int().sum().item()
    print("""----Data statistics------
#Nodes %d
#Edges %d
#Classes/Labels (multi binary labels) %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_nodes, n_edges, n_classes,
n_train_samples,
n_val_samples,
n_test_samples))
# load sampler
if args.sampler == "node":
subg_iter = SAINTNodeSampler(args.node_budget, args.dataset, g,
train_nid, args.num_repeat)
elif args.sampler == "edge":
subg_iter = SAINTEdgeSampler(args.edge_budget, args.dataset, g,
train_nid, args.num_repeat)
elif args.sampler == "rw":
subg_iter = SAINTRandomWalkSampler(args.num_roots, args.length, args.dataset, g,
train_nid, args.num_repeat)
# set device for dataset tensors
if args.gpu < 0:
cuda = False
else:
cuda = True
torch.cuda.set_device(args.gpu)
val_mask = val_mask.cuda()
test_mask = test_mask.cuda()
g = g.to(args.gpu)
print('labels shape:', g.ndata['label'].shape)
print("features shape:", g.ndata['feat'].shape)
model = GCNNet(
in_dim=in_feats,
hid_dim=args.n_hidden,
out_dim=n_classes,
arch=args.arch,
dropout=args.dropout,
batch_norm=not args.no_batch_norm,
aggr=args.aggr
)
if cuda:
model.cuda()
# logger and so on
log_dir = save_log_dir(args)
logger = Logger(os.path.join(log_dir, 'loggings'))
logger.write(args)
# use optimizer
optimizer = torch.optim.Adam(model.parameters(),
lr=args.lr)
# set train_nids to cuda tensor
if cuda:
train_nid = torch.from_numpy(train_nid).cuda()
        print("GPU memory allocated before training (MB):",
              torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)
start_time = time.time()
best_f1 = -1
for epoch in range(args.n_epochs):
for j, subg in enumerate(subg_iter):
            # move the sampled subgraph to the same device as the model
if cuda:
subg = subg.to(torch.cuda.current_device())
model.train()
# forward
pred = model(subg)
batch_labels = subg.ndata['label']
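            # GraphSAINT loss normalization: weight each node's loss by l_n (computed by the
            # sampler) so the minibatch loss is an unbiased estimator of the full-graph loss.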
if multilabel:
loss = F.binary_cross_entropy_with_logits(pred, batch_labels, reduction='sum',
weight=subg.ndata['l_n'].unsqueeze(1))
else:
loss = F.cross_entropy(pred, batch_labels, reduction='none')
loss = (subg.ndata['l_n'] * loss).sum()
optimizer.zero_grad()
loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
if j == len(subg_iter) - 1:
                print(f"Epoch {epoch + 1}/{args.n_epochs}, iteration {j + 1}/"
                      f"{len(subg_iter)}: training loss", loss.item())
# evaluate
if epoch % args.val_every == 0:
val_f1_mic, val_f1_mac = evaluate(
model, g, labels, val_mask, multilabel)
print(
"Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(val_f1_mic, val_f1_mac))
if val_f1_mic > best_f1:
best_f1 = val_f1_mic
print('new best val f1:', best_f1)
torch.save(model.state_dict(), os.path.join(
log_dir, 'best_model.pkl'))
end_time = time.time()
    print(f'Training time: {end_time - start_time:.2f}s')
# test
if args.use_val:
model.load_state_dict(torch.load(os.path.join(
log_dir, 'best_model.pkl')))
test_f1_mic, test_f1_mac = evaluate(
model, g, labels, test_mask, multilabel)
print("Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(test_f1_mic, test_f1_mac))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GraphSAINT')
# data source params
parser.add_argument("--dataset", type=str, choices=['ppi', 'flickr'], default='ppi',
help="Name of dataset.")
# cuda params
parser.add_argument("--gpu", type=int, default=-1,
help="GPU index. Default: -1, using CPU.")
# sampler params
parser.add_argument("--sampler", type=str, default="node", choices=['node', 'edge', 'rw'],
help="Type of sampler")
parser.add_argument("--node-budget", type=int, default=6000,
help="Expected number of sampled nodes when using node sampler")
parser.add_argument("--edge-budget", type=int, default=4000,
help="Expected number of sampled edges when using edge sampler")
parser.add_argument("--num-roots", type=int, default=3000,
help="Expected number of sampled root nodes when using random walk sampler")
parser.add_argument("--length", type=int, default=2,
help="The length of random walk when using random walk sampler")
parser.add_argument("--num-repeat", type=int, default=50,
help="Number of times of repeating sampling one node to estimate edge / node probability")
# model params
parser.add_argument("--n-hidden", type=int, default=512,
help="Number of hidden gcn units")
parser.add_argument("--arch", type=str, default="1-0-1-0",
help="Network architecture. 1 means an order-1 layer (self feature plus 1-hop neighbor "
"feature), and 0 means an order-0 layer (self feature only)")
parser.add_argument("--dropout", type=float, default=0,
help="Dropout rate")
    parser.add_argument("--no-batch-norm", action='store_true',
                        help="Disable batch norm (it is enabled by default)")
parser.add_argument("--aggr", type=str, default="concat", choices=['mean', 'concat'],
help="How to aggregate the self feature and neighbor features")
# training params
parser.add_argument("--n-epochs", type=int, default=100,
help="Number of training epochs")
parser.add_argument("--lr", type=float, default=0.01,
help="Learning rate")
parser.add_argument("--val-every", type=int, default=1,
help="Frequency of evaluation on the validation set in number of epochs")
    parser.add_argument("--use-val", action='store_true',
                        help="Whether to evaluate the best validation checkpoint on the test set")
parser.add_argument("--log-dir", type=str, default='none',
help="Log file will be saved to log/{dataset}/{log_dir}")
args = parser.parse_args()
print(args)
main(args)
import json
import os
from functools import namedtuple
import scipy.sparse
from sklearn.preprocessing import StandardScaler
import dgl
import numpy as np
import torch
from sklearn.metrics import f1_score
class Logger(object):
'''A custom logger to log stdout to a logging file.'''
def __init__(self, path):
"""Initialize the logger.
Parameters
        ----------
        path : str
            The path of the file to write logs to.
"""
self.path = path
def write(self, s):
with open(self.path, 'a') as f:
f.write(str(s))
print(s)
return
def save_log_dir(args):
log_dir = './log/{}/{}'.format(args.dataset, args.log_dir)
os.makedirs(log_dir, exist_ok=True)
return log_dir
def calc_f1(y_true, y_pred, multilabel):
if multilabel:
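        # A logit above 0 corresponds to a sigmoid probability above 0.5, so threshold at 0.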
y_pred[y_pred > 0] = 1
y_pred[y_pred <= 0] = 0
else:
y_pred = np.argmax(y_pred, axis=1)
return f1_score(y_true, y_pred, average="micro"), \
f1_score(y_true, y_pred, average="macro")
def evaluate(model, g, labels, mask, multilabel=False):
model.eval()
with torch.no_grad():
logits = model(g)
logits = logits[mask]
labels = labels[mask]
f1_mic, f1_mac = calc_f1(labels.cpu().numpy(),
logits.cpu().numpy(), multilabel)
return f1_mic, f1_mac
# Load data in GraphSAINT's format and convert it to a DGL graph
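# Expected files under data/{dataset}/: adj_full.npz (full-graph adjacency), adj_train.npz
# (adjacency among training nodes), role.json (train/val/test node ids), feats.npy (node
# features) and class_map.json (node id -> label).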
def load_data(args, multilabel):
prefix = "data/{}".format(args.dataset)
DataType = namedtuple('Dataset', ['num_classes', 'train_nid', 'g'])
    adj_full = scipy.sparse.load_npz('./{}/adj_full.npz'.format(prefix)).astype(bool)
g = dgl.from_scipy(adj_full)
num_nodes = g.num_nodes()
    adj_train = scipy.sparse.load_npz('./{}/adj_train.npz'.format(prefix)).astype(bool)
train_nid = np.array(list(set(adj_train.nonzero()[0])))
role = json.load(open('./{}/role.json'.format(prefix)))
mask = np.zeros((num_nodes,), dtype=bool)
train_mask = mask.copy()
train_mask[role['tr']] = True
val_mask = mask.copy()
val_mask[role['va']] = True
test_mask = mask.copy()
test_mask[role['te']] = True
feats = np.load('./{}/feats.npy'.format(prefix))
scaler = StandardScaler()
scaler.fit(feats[train_nid])
feats = scaler.transform(feats)
class_map = json.load(open('./{}/class_map.json'.format(prefix)))
class_map = {int(k): v for k, v in class_map.items()}
if multilabel:
# Multi-label binary classification
num_classes = len(list(class_map.values())[0])
class_arr = np.zeros((num_nodes, num_classes))
for k, v in class_map.items():
class_arr[k] = v
else:
num_classes = max(class_map.values()) - min(class_map.values()) + 1
class_arr = np.zeros((num_nodes,))
for k, v in class_map.items():
class_arr[k] = v
g.ndata['feat'] = torch.tensor(feats, dtype=torch.float)
g.ndata['label'] = torch.tensor(class_arr, dtype=torch.float if multilabel else torch.long)
g.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool)
g.ndata['val_mask'] = torch.tensor(val_mask, dtype=torch.bool)
g.ndata['test_mask'] = torch.tensor(test_mask, dtype=torch.bool)
data = DataType(g=g, num_classes=num_classes, train_nid=train_nid)
return data