Commit a3febc06 authored by kitaev-chen, committed by VoVAllen

[Model] Add GIN Model (#471)

* add gin model

* convert dataset.py to on-the-fly data loading and move it into the dgl.data module

* convert dataset.py to on-the-fly data loading and move it into the dgl.data module;
Python code checked

* modified documentation and referenced TUDataset; checked the Python part and bypassed the C++ part due to an error

* change tensors to numpy arrays in the dataset and convert them back in the DataLoader's collate function

* fix minor formatting issues

* moved logging; adjusted tqdm etc
parent fb6af16f
# IDE
.idea
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
......
......@@ -15,6 +15,7 @@ Contributors
* [Yizhi Liu](https://github.com/yzhliu): RGCN in MXNet
* [@hbsun2113](https://github.com/hbsun2113): GraphSAGE in Pytorch
* [Tianyi Zhang](https://github.com/Tiiiger): SGC in Pytorch
* [Jun Chen](https://github.com/kitaev-chen): GIN in Pytorch
Other improvement
* [Brett Koonce](https://github.com/brettkoonce)
......
......@@ -33,6 +33,7 @@ Mini graph classification dataset
.. autoclass:: MiniGCDataset
:members: __getitem__, __len__, num_classes
Graph kernel dataset
````````````````````
......@@ -41,6 +42,16 @@ For more information about the dataset, see `Benchmark Data Sets for Graph Kerne
.. autoclass:: TUDataset
:members: __getitem__, __len__
Graph isomorphism network dataset
```````````````````````````````````
A compact subset of the graph kernel datasets
.. autoclass:: GINDataset
:members: __getitem__, __len__
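A minimal usage sketch (the class becomes importable from ``dgl.data`` once this module is included)::

    from dgl.data import GINDataset

    # self_loop=True adds a self-loop edge to every node
    dataset = GINDataset(name='MUTAG', self_loop=True)
    graph, label = dataset[0]  # (dgl.DGLGraph, int)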
Protein-Protein Interaction dataset
```````````````````````````````````
......
Graph Isomorphism Network (GIN)
============
- Paper link: [arXiv](https://arxiv.org/abs/1810.00826) [OpenReview](https://openreview.net/forum?id=ryGs6iA5Km)
- Author's code repo: [https://github.com/weihua916/powerful-gnns](https://github.com/weihua916/powerful-gnns).
Dependencies
------------
- PyTorch 1.0.1+
- sklearn
- tqdm
```bash
pip install torch sklearn tqdm
```
How to run
----------
An experiment on GIN with default settings can be run with
```bash
python main.py
```
An experiment on GIN with customized settings can be run with
```bash
python main.py [--device 0 | --disable-cuda] --dataset COLLAB \
--graph_pooling_type max --neighbor_pooling_type sum
```
Results
-------
Run the following command to use sum pooling for both neighbor and graph aggregation
(tested datasets: "MUTAG" (default), "COLLAB", "IMDBBINARY", "IMDBMULTI"):
```bash
python main.py --dataset MUTAG --device 0 \
--graph_pooling_type sum --neighbor_pooling_type sum
```
* MUTAG: 0.85 (paper: ~0.89)
* COLLAB: 0.89 (paper: ~0.80)
* IMDBBINARY: 0.76 (paper: ~0.75)
* IMDBMULTI: 0.51 (paper: ~0.52)
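The same pipeline can also be driven programmatically. The sketch below is illustrative only: it wires together the `GINDataset`, `GraphDataLoader`, and `GIN` classes from this example with hyperparameters that happen to match the parser defaults, and computes the loss on one batch.

```python
import torch
import torch.nn as nn

from dgl.data.gindt import GINDataset
from dataloader import GraphDataLoader, collate
from gin import GIN

device = torch.device('cpu')  # or torch.device('cuda:0')

# self_loop=True corresponds to not learning epsilon (see main.py)
dataset = GINDataset('MUTAG', self_loop=True)
trainloader, validloader = GraphDataLoader(
    dataset, batch_size=32, device=device, collate_fn=collate,
    seed=0, shuffle=True, split_name='fold10', fold_idx=0).train_valid_loader()

model = GIN(num_layers=5, num_mlp_layers=2,
            input_dim=dataset.dim_nfeats, hidden_dim=64,
            output_dim=dataset.gclasses, final_dropout=0.5,
            learn_eps=False, graph_pooling_type='sum',
            neighbor_pooling_type='sum', device=device).to(device)

criterion = nn.CrossEntropyLoss()
graphs, labels = next(iter(trainloader))  # graphs is a batched DGLGraph
loss = criterion(model(graphs), labels.to(device))
print(loss.item())
```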
"""
PyTorch compatible dataloader
"""
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold
import dgl
# default collate function
def collate(samples):
# The input `samples` is a list of pairs (graph, label).
graphs, labels = map(list, zip(*samples))
for g in graphs:
# deal with node feats
for feat in g.node_attr_schemes().keys():
# TODO: torch.Tensor is not the recommended constructor, but
# torch.DoubleTensor / torch.tensor currently fail in the DGL
# executor/backend with "RuntimeError: expected type
# torch.cuda.DoubleTensor but got torch.cuda.FloatTensor",
# so cast node features to float32 tensors here.
g.ndata[feat] = torch.Tensor(g.ndata[feat])
# no edge feats
batched_graph = dgl.batch(graphs)
labels = torch.tensor(labels)
return batched_graph, labels
class GraphDataLoader():
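"""Wrap a graph classification dataset into train/validation DataLoaders.
The split is either stratified 10-fold ('fold10', selecting fold `fold_idx`)
or a single random split ('rand', using `split_ratio`).
"""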
def __init__(self,
dataset,
batch_size,
device,
collate_fn=collate,
seed=0,
shuffle=True,
split_name='fold10',
fold_idx=0,
split_ratio=0.7):
self.shuffle = shuffle
self.seed = seed
self.kwargs = {'pin_memory': True} if 'cuda' in device.type else {}
labels = [l for _, l in dataset]
if split_name == 'fold10':
train_idx, valid_idx = self._split_fold10(
labels, fold_idx, seed, shuffle)
elif split_name == 'rand':
train_idx, valid_idx = self._split_rand(
labels, split_ratio, seed, shuffle)
else:
raise NotImplementedError()
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
self.train_loader = DataLoader(
dataset, sampler=train_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
self.valid_loader = DataLoader(
dataset, sampler=valid_sampler,
batch_size=batch_size, collate_fn=collate_fn, **self.kwargs)
def train_valid_loader(self):
return self.train_loader, self.valid_loader
def _split_fold10(self, labels, fold_idx=0, seed=0, shuffle=True):
''' Stratified 10-fold split; returns train/valid indices for fold `fold_idx`. '''
assert 0 <= fold_idx < 10, "fold_idx must be from 0 to 9."
skf = StratifiedKFold(n_splits=10, shuffle=shuffle, random_state=seed)
idx_list = []
for idx in skf.split(np.zeros(len(labels)), labels):  # split(x, y)
idx_list.append(idx)
train_idx, valid_idx = idx_list[fold_idx]
print("train_set : test_set = %d : %d" % (len(train_idx), len(valid_idx)))
return train_idx, valid_idx
def _split_rand(self, labels, split_ratio=0.7, seed=0, shuffle=True):
num_entries = len(labels)
indices = list(range(num_entries))
np.random.seed(seed)
np.random.shuffle(indices)
split = int(math.floor(split_ratio * num_entries))
train_idx, valid_idx = indices[:split], indices[split:]
print("train_set : test_set = %d : %d" % (len(train_idx), len(valid_idx)))
return train_idx, valid_idx
"""
How Powerful are Graph Neural Networks
https://arxiv.org/abs/1810.00826
https://openreview.net/forum?id=ryGs6iA5Km
Author's implementation: https://github.com/weihua916/powerful-gnns
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
# Message function: send the node feature 'h' to neighbors as 'm'.
msg = fn.copy_src(src='h', out='m')
# Built-in reducers for sum/max neighbor pooling.
reduce_sum = fn.sum(msg='m', out='h')
reduce_max = fn.max(msg='m', out='h')
# UDF reducer for mean neighbor pooling.
def reduce_mean(nodes):
return {'h': torch.mean(nodes.mailbox['m'], dim=1)}
class ApplyNodes(nn.Module):
"""Update the node feature hv with MLP, BN and ReLU."""
def __init__(self, mlp, layer):
super(ApplyNodes, self).__init__()
self.mlp = mlp
self.bn = nn.BatchNorm1d(self.mlp.output_dim)
self.layer = layer
def forward(self, nodes):
h = self.mlp(nodes.data['h'])
h = self.bn(h)
h = F.relu(h)
return {'h': h}
class GINLayer(nn.Module):
"""Neighbor pooling and reweight nodes before send graph into MLP"""
def __init__(self, eps, layer, mlp, neighbor_pooling_type, learn_eps):
super(GINLayer, self).__init__()
self.bn = nn.BatchNorm1d(mlp.output_dim)
self.neighbor_pooling_type = neighbor_pooling_type
self.eps = eps
self.learn_eps = learn_eps
self.layer = layer
self.apply_mod = ApplyNodes(mlp, layer)
def forward(self, g, feature):
g.ndata['h'] = feature
if self.neighbor_pooling_type == 'sum':
reduce_func = reduce_sum
elif self.neighbor_pooling_type == 'mean':
reduce_func = reduce_mean
elif self.neighbor_pooling_type == 'max':
reduce_func = reduce_max
else:
raise NotImplementedError()
h = feature # h0
g.update_all(msg, reduce_func)
pooled = g.ndata['h']
# reweight the center node when aggregating it with its neighbors
if self.learn_eps:
pooled = pooled + (1 + self.eps[self.layer])*h
g.ndata['h'] = pooled
g.apply_nodes(func=self.apply_mod)
return g.ndata.pop('h')
class MLP(nn.Module):
"""MLP with linear output"""
def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
"""MLP layers construction
Parameters
---------
num_layers: int
The number of linear layers
input_dim: int
The dimensionality of input features
hidden_dim: int
The dimensionality of hidden units at ALL layers
output_dim: int
The dimensionality of the output features
"""
super(MLP, self).__init__()
self.linear_or_not = True # default is linear model
self.num_layers = num_layers
self.output_dim = output_dim
if num_layers < 1:
raise ValueError("number of layers should be positive!")
elif num_layers == 1:
# Linear model
self.linear = nn.Linear(input_dim, output_dim)
else:
# Multi-layer model
self.linear_or_not = False
self.linears = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
self.linears.append(nn.Linear(input_dim, hidden_dim))
for layer in range(num_layers - 2):
self.linears.append(nn.Linear(hidden_dim, hidden_dim))
self.linears.append(nn.Linear(hidden_dim, output_dim))
for layer in range(num_layers - 1):
self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
def forward(self, x):
if self.linear_or_not:
# If linear model
return self.linear(x)
else:
# If MLP
h = x
for layer in range(self.num_layers - 1):
h = F.relu(self.batch_norms[layer](self.linears[layer](h)))
return self.linears[self.num_layers - 1](h)
class GIN(nn.Module):
"""GIN model"""
def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim,
output_dim, final_dropout, learn_eps, graph_pooling_type,
neighbor_pooling_type, device):
"""model parameters setting
Parameters
---------
num_layers: int
The number of layers in the neural network (including the input layer)
num_mlp_layers: int
The number of linear layers in each MLP
input_dim: int
The dimensionality of input features
hidden_dim: int
The dimensionality of hidden units at ALL layers
output_dim: int
The number of classes for prediction
final_dropout: float
dropout ratio on the final linear layer
learn_eps: boolean
If True, learn epsilon to distinguish center nodes from neighbors
If False, aggregate neighbors and center nodes altogether.
neighbor_pooling_type: str
how to aggregate neighbors (sum, mean, or max)
graph_pooling_type: str
how to aggregate all node features of a graph (sum, mean or max)
device: str
which device to use
"""
super(GIN, self).__init__()
self.final_dropout = final_dropout
self.device = device
self.num_layers = num_layers
self.graph_pooling_type = graph_pooling_type
self.learn_eps = learn_eps
self.eps = nn.Parameter(torch.zeros(self.num_layers - 1))
# List of MLPs
self.ginlayers = torch.nn.ModuleList()
self.batch_norms = torch.nn.ModuleList()
for layer in range(self.num_layers - 1):
if layer == 0:
mlp = MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)
else:
mlp = MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)
self.ginlayers.append(GINLayer(
self.eps, layer, mlp, neighbor_pooling_type, self.learn_eps))
self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
# Linear function for graph poolings of output of each layer
# which maps the output of different layers into a prediction score
self.linears_prediction = torch.nn.ModuleList()
for layer in range(num_layers):
if layer == 0:
self.linears_prediction.append(
nn.Linear(input_dim, output_dim))
else:
self.linears_prediction.append(
nn.Linear(hidden_dim, output_dim))
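# Readout: at every layer (including the raw input features), node states are
# pooled per graph, passed through that layer's linear classifier with dropout,
# and the per-layer scores are summed to give the final prediction.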
def forward(self, g):
h = g.ndata['attr']
h = h.to(self.device)
# list of hidden representation at each layer (including input)
hidden_rep = [h]
for layer in range(self.num_layers - 1):
h = self.ginlayers[layer](g, h)
hidden_rep.append(h)
score_over_layer = 0
# perform pooling over all nodes in each graph in every layer
for layer, h in enumerate(hidden_rep):
g.ndata['h'] = h
if self.graph_pooling_type == 'sum':
pooled_h = dgl.sum_nodes(g, 'h')
elif self.graph_pooling_type == 'mean':
pooled_h = dgl.mean_nodes(g, 'h')
elif self.graph_pooling_type == 'max':
pooled_h = dgl.max_nodes(g, 'h')
else:
raise NotImplementedError()
score_over_layer += F.dropout(
self.linears_prediction[layer](pooled_h),
self.final_dropout,
training=self.training)
return score_over_layer
import sys
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.data.gindt import GINDataset
from dataloader import GraphDataLoader, collate
from parser import Parser
from gin import GIN
def train(args, net, trainloader, optimizer, criterion, epoch):
net.train()
running_loss = 0
total_iters = len(trainloader)
# set the tqdm position offset so the bar does not overlap the cursor
bar = tqdm(range(total_iters), unit='batch', position=2, file=sys.stdout)
for pos, (graphs, labels) in zip(bar, trainloader):
# batch graphs will be shipped to device in forward part of model
labels = labels.to(args.device)
outputs = net(graphs)
loss = criterion(outputs, labels)
running_loss += loss.item()
# backprop
if optimizer is not None:
optimizer.zero_grad()
loss.backward()
optimizer.step()
# report
bar.set_description('epoch-{}'.format(epoch))
bar.close()
# report the average per-batch loss (the last batch may be smaller)
running_loss = running_loss / total_iters
return running_loss
def eval_net(args, net, dataloader, criterion):
net.eval()
total = 0
total_loss = 0
total_correct = 0
# total_iters = len(dataloader)
for data in dataloader:
graphs, labels = data
labels = labels.to(args.device)
total += len(labels)
outputs = net(graphs)
_, predicted = torch.max(outputs.data, 1)
total_correct += (predicted == labels.data).sum().item()
loss = criterion(outputs, labels)
# CrossEntropyLoss averages over the batch by default, so weight by batch size
total_loss += loss.item() * len(labels)
loss, acc = 1.0*total_loss / total, 1.0*total_correct / total
net.train()
return loss, acc
def main(args):
# set up seeds (args.seed is parsed, but a fixed seed of 0 is used here)
torch.manual_seed(seed=0)
np.random.seed(seed=0)
is_cuda = not args.disable_cuda and torch.cuda.is_available()
if is_cuda:
args.device = torch.device("cuda:" + str(args.device))
torch.cuda.manual_seed_all(seed=0)
else:
args.device = torch.device("cpu")
dataset = GINDataset(args.dataset, not args.learn_eps)
trainloader, validloader = GraphDataLoader(
dataset, batch_size=args.batch_size, device=args.device,
collate_fn=collate, seed=args.seed, shuffle=True,
split_name='fold10', fold_idx=args.fold_idx).train_valid_loader()
# or split_name='rand', split_ratio=0.7
model = GIN(
args.num_layers, args.num_mlp_layers,
dataset.dim_nfeats, args.hidden_dim, dataset.gclasses,
args.final_dropout, args.learn_eps,
args.graph_pooling_type, args.neighbor_pooling_type,
args.device).to(args.device)
criterion = nn.CrossEntropyLoss()  # the default reduction averages over the batch
optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
# it's not cost-effective to handle the cursor and initialize at 0
# https://stackoverflow.com/a/23121189
tbar = tqdm(range(args.epochs), unit="epoch", position=3, ncols=0, file=sys.stdout)
vbar = tqdm(range(args.epochs), unit="epoch", position=4, ncols=0, file=sys.stdout)
lrbar = tqdm(range(args.epochs), unit="epoch", position=5, ncols=0, file=sys.stdout)
for epoch, _, _ in zip(tbar, vbar, lrbar):
scheduler.step()
train(args, model, trainloader, optimizer, criterion, epoch)
train_loss, train_acc = eval_net(
args, model, trainloader, criterion)
tbar.set_description(
'train set - average loss: {:.4f}, accuracy: {:.0f}%'
.format(train_loss, 100. * train_acc))
valid_loss, valid_acc = eval_net(
args, model, validloader, criterion)
vbar.set_description(
'valid set - average loss: {:.4f}, accuracy: {:.0f}%'
.format(valid_loss, 100. * valid_acc))
if not args.filename == "":
with open(args.filename, 'a') as f:
f.write('%s %s %s %s' % (
args.dataset,
args.learn_eps,
args.neighbor_pooling_type,
args.graph_pooling_type
))
f.write("\n")
f.write("%f %f %f %f" % (
train_loss,
train_acc,
valid_loss,
valid_acc
))
f.write("\n")
lrbar.set_description(
"the learning eps with learn_eps={} is: {}".format(
args.learn_eps, model.eps.data))
tbar.close()
vbar.close()
lrbar.close()
if __name__ == '__main__':
args = Parser(description='GIN').args
print('show all arguments configuration...')
print(args)
main(args)
"""Parser for arguments
Put all arguments in one file and group similar arguments
"""
import argparse
class Parser():
def __init__(self, description):
'''
arguments parser
'''
self.parser = argparse.ArgumentParser(description=description)
self.args = None
self._parse()
def _parse(self):
# dataset
self.parser.add_argument(
'--dataset', type=str, default="MUTAG",
help='name of dataset (default: MUTAG)')
self.parser.add_argument(
'--batch_size', type=int, default=32,
help='batch size for training and validation (default: 32)')
self.parser.add_argument(
'--fold_idx', type=int, default=0,
help='the index(<10) of fold in 10-fold validation.')
self.parser.add_argument(
'--filename', type=str, default="",
help='output file')
# device
self.parser.add_argument(
'--disable-cuda', action='store_true',
help='Disable CUDA')
self.parser.add_argument(
'--device', type=int, default=0,
help='which gpu device to use (default: 0)')
# net
self.parser.add_argument(
'--net', type=str, default="gin",
help='gnn net (default: gin)')
self.parser.add_argument(
'--num_layers', type=int, default=5,
help='number of layers (default: 5)')
self.parser.add_argument(
'--num_mlp_layers', type=int, default=2,
help='number of MLP layers (default: 2). 1 means linear model.')
self.parser.add_argument(
'--hidden_dim', type=int, default=64,
help='number of hidden units (default: 64)')
# graph
self.parser.add_argument(
'--graph_pooling_type', type=str,
default="sum", choices=["sum", "mean", "max"],
help='type of graph pooling: sum, mean or max')
self.parser.add_argument(
'--neighbor_pooling_type', type=str,
default="sum", choices=["sum", "mean", "max"],
help='type of neighbor pooling: sum, mean or max')
self.parser.add_argument(
'--learn_eps', action="store_true",
help='learn the epsilon weighting')
self.parser.add_argument(
'--degree_as_tag', action="store_true",
help='take the degree of nodes as input feature')
# learning
self.parser.add_argument(
'--seed', type=int, default=0,
help='random seed (default: 0)')
self.parser.add_argument(
'--epochs', type=int, default=350,
help='number of epochs to train (default: 350)')
self.parser.add_argument(
'--lr', type=float, default=0.01,
help='learning rate (default: 0.01)')
self.parser.add_argument(
'--final_dropout', type=float, default=0.5,
help='final layer dropout (default: 0.5)')
# done
self.args = self.parser.parse_args()
......@@ -10,6 +10,8 @@ from .sbm import SBMMixture
from .reddit import RedditDataset
from .ppi import PPIDataset
from .tu import TUDataset
from .gindt import GINDataset
def register_data_args(parser):
parser.add_argument("--dataset", type=str, required=False,
......
"""Dataset for Graph Isomorphism Network(GIN)
(chen jun): Used for compacted graph kernel dataset in GIN
Data sets include:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
"""
import os
import numpy as np
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..graph import DGLGraph
_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
class GINDataset(object):
"""Datasets for Graph Isomorphism Network (GIN)
Adapted from https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip.
The dataset contains the compact format of popular graph kernel datasets, which includes:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
This dataset class processes all datasets listed above. For more graph kernel datasets,
see :class:`TUDataset`
Parameters
---------
name: str
dataset name, one of below -
('MUTAG', 'COLLAB', \
'IMDBBINARY', 'IMDBMULTI', \
'NCI1', 'PROTEINS', 'PTC', \
'REDDITBINARY', 'REDDITMULTI5K')
self_loop: boolean
whether to add a self-loop edge to every node
degree_as_nlabel: boolean
whether to use the node degree as node label and feature
"""
def __init__(self, name, self_loop, degree_as_nlabel=False):
"""Initialize the dataset."""
self.name = name # MUTAG
self.ds_name = 'nig'
self.extract_dir = self._download()
self.file = self._file_path()
self.self_loop = self_loop
self.graphs = []
self.labels = []
# relabel
self.glabel_dict = {}
self.nlabel_dict = {}
self.elabel_dict = {}
self.ndegree_dict = {}
# global num
self.N = 0 # total graphs number
self.n = 0 # total nodes number
self.m = 0 # total edges number
# global num of classes
self.gclasses = 0
self.nclasses = 0
self.eclasses = 0
self.dim_nfeats = 0
# flags
self.degree_as_nlabel = degree_as_nlabel
self.nattrs_flag = False
self.nlabels_flag = False
self.verbosity = False
# calc all values
self._load()
def __len__(self):
"""Return the number of graphs in the dataset."""
return len(self.graphs)
def __getitem__(self, idx):
"""Get the i^th sample.
Paramters
---------
idx : int
The sample index.
Returns
-------
(dgl.DGLGraph, int)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
def _download(self):
download_dir = get_download_dir()
zip_file_path = os.path.join(
download_dir, "{}.zip".format(self.ds_name))
# TODO move to dgl host _get_dgl_url
download(_url, path=zip_file_path)
extract_dir = os.path.join(
download_dir, "{}".format(self.ds_name))
extract_archive(zip_file_path, extract_dir)
return extract_dir
def _file_path(self):
return os.path.join(self.extract_dir, "dataset", self.name, "{}.txt".format(self.name))
def _load(self):
""" Loads input dataset from dataset/NAME/NAME.txt file
"""
print('loading data...')
with open(self.file, 'r') as f:
# line_1 == N, total number of graphs
self.N = int(f.readline().strip())
for i in range(self.N):
if (i + 1) % 10 == 0 and self.verbosity is True:
print('processing graph {}...'.format(i + 1))
grow = f.readline().strip().split()
# line_2 == [n_nodes, l] is equal to
# [node number of a graph, class label of a graph]
n_nodes, glabel = [int(w) for w in grow]
# relabel graphs
if glabel not in self.glabel_dict:
mapped = len(self.glabel_dict)
self.glabel_dict[glabel] = mapped
self.labels.append(self.glabel_dict[glabel])
g = DGLGraph()
g.add_nodes(n_nodes)
nlabels = [] # node labels
nattrs = [] # node attributes if it has
m_edges = 0
for j in range(n_nodes):
nrow = f.readline().strip().split()
# handle edges and attributes (if present)
tmp = int(nrow[1]) + 2  # tmp == 2 + #edges
if tmp == len(nrow):
# no node attributes
nrow = [int(w) for w in nrow]
nattr = None
elif tmp < len(nrow):
# node attributes follow the neighbor list
nattr = [float(w) for w in nrow[tmp:]]
nrow = [int(w) for w in nrow[:tmp]]
nattrs.append(nattr)
else:
raise Exception('edge number is incorrect!')
# relabel nodes if the dataset has node labels
# (if it doesn't, every nrow[0] == 0)
if nrow[0] not in self.nlabel_dict:
mapped = len(self.nlabel_dict)
self.nlabel_dict[nrow[0]] = mapped
nlabels.append(self.nlabel_dict[nrow[0]])
m_edges += nrow[1]
g.add_edges(j, nrow[2:])
# add self loop
if self.self_loop:
m_edges += 1
g.add_edge(j, j)
if (j + 1) % 10 == 0 and self.verbosity is True:
print(
'processing node {} of graph {}...'.format(
j + 1, i + 1))
print('this node has {} edges.'.format(
nrow[1]))
if nattrs != []:
nattrs = np.stack(nattrs)
g.ndata['attr'] = nattrs
self.nattrs_flag = True
else:
nattrs = None
g.ndata['label'] = np.array(nlabels)
if len(self.nlabel_dict) > 1:
self.nlabels_flag = True
assert len(g) == n_nodes
# update statistics of graphs
self.n += n_nodes
self.m += m_edges
self.graphs.append(g)
# if no attr
if not self.nattrs_flag:
print('there are no node features in this dataset!')
label2idx = {}
# generate node attr by node degree
if self.degree_as_nlabel:
print('generate node features by node degree...')
nlabel_set = set([])
for g in self.graphs:
# ideally this label should not be overwritten, in case users want
# to keep it, but usually no features also means no labels.
g.ndata['label'] = g.in_degrees()
# extracting unique node labels
nlabel_set = nlabel_set.union(set(g.ndata['label']))
nlabel_set = list(nlabel_set)
# in case the labels/degrees are not contiguous integers
self.ndegree_dict = {
nlabel_set[i]: i
for i in range(len(nlabel_set))
}
label2idx = self.ndegree_dict
# generate node attr by node label
else:
print('generate node features by node label...')
label2idx = self.nlabel_dict
for g in self.graphs:
g.ndata['attr'] = np.zeros((
g.number_of_nodes(), len(label2idx)))
g.ndata['attr'][range(g.number_of_nodes(
)), [label2idx[nl.item()] for nl in g.ndata['label']]] = 1
# after load, get the #classes and #dim
self.gclasses = len(self.glabel_dict)
self.nclasses = len(self.nlabel_dict)
self.eclasses = len(self.elabel_dict)
self.dim_nfeats = len(self.graphs[0].ndata['attr'][0])
print('Done.')
print(
"""
-------- Data Statistics --------
#Graphs: %d
#Graph Classes: %d
#Nodes: %d
#Node Classes: %d
#Node Features Dim: %d
#Edges: %d
#Edge Classes: %d
Avg. of #Nodes: %.2f
Avg. of #Edges: %.2f
Graph Relabeled: %s
Node Relabeled: %s
Degree Relabeled (if degree_as_nlabel=True): %s\n""" % (
self.N, self.gclasses, self.n, self.nclasses,
self.dim_nfeats, self.m, self.eclasses,
self.n / self.N, self.m / self.N, self.glabel_dict,
self.nlabel_dict, self.ndegree_dict))