Unverified commit 9bcce7be authored by Xiangkun Hu, committed by GitHub

[Dataset] PPIDataset (#1906)



* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* Revert "Revert "PPIDataset""

This reverts commit 6938a4cbe3ac6e38d3e0188b5699e5c952a6102e.

* update doc string
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 4f7bd0d0
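
Note: the diff below moves PPI consumers from tuple-style items and flat `features`/`labels` arrays to graphs that carry everything in `ndata`. A minimal sketch of the new access pattern, assuming the refactored PPIDataset from this commit and a successful download:

    from dgl.data.ppi import PPIDataset

    dataset = PPIDataset(mode='valid')   # downloads and processes on first use
    print(dataset.num_labels)            # 121 label dimensions per node
    for g in dataset:
        feat = g.ndata['feat']           # node features, one graph at a time
        label = g.ndata['label']         # multi-hot node labels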
@@ -30,34 +30,26 @@ def main(args):
     # load and preprocess dataset
     data = load_data(args)
-    train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)
+    g = data.g
+    train_mask = g.ndata['train_mask']
+    val_mask = g.ndata['val_mask']
+    test_mask = g.ndata['test_mask']
+    labels = g.ndata['label']
+    train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64)
     # Normalize features
     if args.normalize:
-        train_feats = data.features[train_nid]
+        feats = g.ndata['feat']
+        train_feats = feats[train_mask]
         scaler = sklearn.preprocessing.StandardScaler()
-        scaler.fit(train_feats)
-        features = scaler.transform(data.features)
-    else:
-        features = data.features
-    features = torch.FloatTensor(features)
-    if not multitask:
-        labels = torch.LongTensor(data.labels)
-    else:
-        labels = torch.FloatTensor(data.labels)
-    if hasattr(torch, 'BoolTensor'):
-        train_mask = torch.BoolTensor(data.train_mask)
-        val_mask = torch.BoolTensor(data.val_mask)
-        test_mask = torch.BoolTensor(data.test_mask)
-    else:
-        train_mask = torch.ByteTensor(data.train_mask)
-        val_mask = torch.ByteTensor(data.val_mask)
-        test_mask = torch.ByteTensor(data.test_mask)
-    in_feats = features.shape[1]
-    n_classes = data.num_labels
-    n_edges = data.graph.number_of_edges()
+        scaler.fit(train_feats.data.numpy())
+        features = scaler.transform(feats.data.numpy())
+        g.ndata['feat'] = torch.FloatTensor(features)
+    in_feats = g.ndata['feat'].shape[1]
+    n_classes = data.num_classes
+    n_edges = g.number_of_edges()
     n_train_samples = train_mask.int().sum().item()
     n_val_samples = val_mask.int().sum().item()
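
The normalization block above follows the standard train-only standardization pattern: fit the scaler on training-node rows only, then transform every row. A self-contained sketch with toy tensors (shapes are illustrative, not from the dataset):

    import sklearn.preprocessing
    import torch

    feats = torch.randn(10, 4)                     # toy node features
    train_mask = torch.zeros(10, dtype=torch.bool)
    train_mask[:6] = True                          # pretend the first 6 nodes are training nodes

    scaler = sklearn.preprocessing.StandardScaler()
    scaler.fit(feats[train_mask].numpy())          # statistics come from training nodes only
    feats = torch.FloatTensor(scaler.transform(feats.numpy()))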
@@ -74,17 +66,12 @@ def main(args):
              n_val_samples,
              n_test_samples))
     # create GCN model
-    g = data.graph
-    g = dgl.graph(g)
     if args.self_loop and not args.dataset.startswith('reddit'):
         g = dgl.remove_self_loop(g)
         g = dgl.add_self_loop(g)
         print("adding self-loop edges")
     # metis only support int64 graph
     g = g.long()
-    g.ndata['features'] = features
-    g.ndata['labels'] = labels
-    g.ndata['train_mask'] = train_mask
     cluster_iterator = ClusterIter(
         args.dataset, g, args.psize, args.batch_size, train_nid, use_pp=args.use_pp)
@@ -99,9 +86,8 @@ def main(args):
         test_mask = test_mask.cuda()
         g = g.to(args.gpu)
-    print(torch.cuda.get_device_name(0))
-    print('labels shape:', labels.shape)
-    print("features shape, ", features.shape)
+    print('labels shape:', g.ndata['label'].shape)
+    print("features shape, ", g.ndata['feat'].shape)
     model = GraphSAGE(in_feats,
                       args.n_hidden,
@@ -144,11 +130,12 @@ def main(args):
     for epoch in range(args.n_epochs):
         for j, cluster in enumerate(cluster_iterator):
             # sync with upper level training graph
-            cluster = cluster.to(torch.cuda.current_device())
+            if cuda:
+                cluster = cluster.to(torch.cuda.current_device())
             model.train()
             # forward
             pred = model(cluster)
-            batch_labels = cluster.ndata['labels']
+            batch_labels = cluster.ndata['label']
             batch_train_mask = cluster.ndata['train_mask']
             loss = loss_f(pred[batch_train_mask],
                           batch_labels[batch_train_mask])
......
@@ -90,7 +90,7 @@ class GraphSAGE(nn.Module):
                 dropout=dropout, use_pp=False, use_lynorm=False))

     def forward(self, g):
-        h = g.ndata['features']
+        h = g.ndata['feat']
         for layer in self.layers:
             h = layer(g, h)
         return h
@@ -57,21 +57,21 @@ class ClusterIter(object):
     def precalc(self, g):
         norm = self.get_norm(g)
         g.ndata['norm'] = norm
-        features = g.ndata['features']
+        features = g.ndata['feat']
         print("features shape, ", features.shape)
         with torch.no_grad():
-            g.update_all(fn.copy_src(src='features', out='m'),
-                         fn.sum(msg='m', out='features'),
+            g.update_all(fn.copy_src(src='feat', out='m'),
+                         fn.sum(msg='m', out='feat'),
                          None)
-            pre_feats = g.ndata['features'] * norm
+            pre_feats = g.ndata['feat'] * norm
             # use graphsage embedding aggregation style
-            g.ndata['features'] = torch.cat([features, pre_feats], dim=1)
+            g.ndata['feat'] = torch.cat([features, pre_feats], dim=1)

     # use one side normalization
     def get_norm(self, g):
         norm = 1. / g.in_degrees().float().unsqueeze(1)
         norm[torch.isinf(norm)] = 0
-        norm = norm.to(self.g.ndata['features'].device)
+        norm = norm.to(self.g.ndata['feat'].device)
         return norm

     def __len__(self):
......
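
The precalc hunk above implements GraphSAGE-style preprocessing: each node's features are concatenated with a degree-normalized sum of its neighbors' features. A self-contained sketch of the same idea on a random graph (the 'agg' field name is ours, not the repo's):

    import dgl
    import dgl.function as fn
    import torch

    g = dgl.rand_graph(8, 20)                     # toy graph: 8 nodes, 20 edges
    g.ndata['feat'] = torch.randn(8, 5)

    norm = 1. / g.in_degrees().float().unsqueeze(1)
    norm[torch.isinf(norm)] = 0                   # guard isolated nodes
    with torch.no_grad():
        g.update_all(fn.copy_src(src='feat', out='m'), fn.sum(msg='m', out='agg'))
        g.ndata['feat'] = torch.cat([g.ndata['feat'], g.ndata['agg'] * norm], dim=1)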
@@ -60,22 +60,23 @@ def evaluate(model, g, labels, mask, multitask=False):
 def load_data(args):
     '''Wraps the dgl's load_data utility to handle ppi special case'''
+    DataType = namedtuple('Dataset', ['num_classes', 'g'])
     if args.dataset != 'ppi':
-        return _load_data(args)
+        dataset = _load_data(args)
+        data = DataType(g=dataset[0], num_classes=dataset.num_classes)
+        return data
     train_dataset = PPIDataset('train')
+    train_graph = dgl.batch([train_dataset[i] for i in range(len(train_dataset))], edge_attrs=None, node_attrs=None)
     val_dataset = PPIDataset('valid')
+    val_graph = dgl.batch([val_dataset[i] for i in range(len(val_dataset))], edge_attrs=None, node_attrs=None)
     test_dataset = PPIDataset('test')
-    PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask',
-                                            'val_mask', 'features', 'labels', 'num_labels', 'graph'])
+    test_graph = dgl.batch([test_dataset[i] for i in range(len(test_dataset))], edge_attrs=None, node_attrs=None)
     G = dgl.batch(
-        [train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None)
-    G = G.to_networkx()
-    # hack to dodge the potential bugs of to_networkx
-    for (n1, n2, d) in G.edges(data=True):
-        d.clear()
-    train_nodes_num = train_dataset.graph.number_of_nodes()
-    test_nodes_num = test_dataset.graph.number_of_nodes()
-    val_nodes_num = val_dataset.graph.number_of_nodes()
+        [train_graph, val_graph, test_graph], edge_attrs=None, node_attrs=None)
+    train_nodes_num = train_graph.number_of_nodes()
+    test_nodes_num = test_graph.number_of_nodes()
+    val_nodes_num = val_graph.number_of_nodes()
     nodes_num = G.number_of_nodes()
     assert(nodes_num == (train_nodes_num + test_nodes_num + val_nodes_num))
     # construct mask
@@ -87,13 +88,9 @@ def load_data(args):
     test_mask = mask.copy()
     test_mask[-test_nodes_num:] = True
-    # construct features
-    features = np.concatenate(
-        [train_dataset.features, val_dataset.features, test_dataset.features], axis=0)
-    labels = np.concatenate(
-        [train_dataset.labels, val_dataset.labels, test_dataset.labels], axis=0)
-    data = PPIDataType(graph=G, train_mask=train_mask, test_mask=test_mask,
-                       val_mask=val_mask, features=features, labels=labels, num_labels=121)
+    G.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool)
+    G.ndata['val_mask'] = torch.tensor(val_mask, dtype=torch.bool)
+    G.ndata['test_mask'] = torch.tensor(test_mask, dtype=torch.bool)
+    data = DataType(g=G, num_classes=train_dataset.num_labels)
     return data
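
The PPI branch above rests on one idea: batch the per-split graphs into a single graph, then record split membership as boolean node masks in `ndata`. A self-contained sketch with toy graphs standing in for the three splits:

    import dgl
    import numpy as np
    import torch

    train_g, val_g, test_g = dgl.rand_graph(5, 10), dgl.rand_graph(3, 6), dgl.rand_graph(2, 4)
    G = dgl.batch([train_g, val_g, test_g])        # nodes of each split stay contiguous

    mask = np.zeros(G.number_of_nodes(), dtype=bool)
    train_mask = mask.copy()
    train_mask[:train_g.number_of_nodes()] = True  # first block of nodes is the train split
    G.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool)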
@@ -17,15 +17,12 @@ import torch.nn.functional as F
 import argparse
 from sklearn.metrics import f1_score
 from gat import GAT
-from dgl.data.ppi import LegacyPPIDataset
+from dgl.data.ppi import PPIDataset
 from torch.utils.data import DataLoader

-def collate(sample):
-    graphs, feats, labels = map(list, zip(*sample))
+def collate(graphs):
     graph = dgl.batch(graphs)
-    feats = torch.from_numpy(np.concatenate(feats))
-    labels = torch.from_numpy(np.concatenate(labels))
-    return graph, feats, labels
+    return graph
 def evaluate(feats, model, subgraph, labels, loss_fcn):
     with torch.no_grad():
@@ -54,15 +51,15 @@ def main(args):
     # define loss function
     loss_fcn = torch.nn.BCEWithLogitsLoss()
     # create the dataset
-    train_dataset = LegacyPPIDataset(mode='train')
-    valid_dataset = LegacyPPIDataset(mode='valid')
-    test_dataset = LegacyPPIDataset(mode='test')
+    train_dataset = PPIDataset(mode='train')
+    valid_dataset = PPIDataset(mode='valid')
+    test_dataset = PPIDataset(mode='test')
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
     valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)
     test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)
-    n_classes = train_dataset.labels.shape[1]
-    num_feats = train_dataset.features.shape[1]
-    g = train_dataset.graph
+    g = train_dataset[0]
+    n_classes = train_dataset.num_labels
+    num_feats = g.ndata['feat'].shape[1]
     g = g.to(device)
     heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
     # define the model
@@ -83,16 +80,13 @@ def main(args):
     for epoch in range(args.epochs):
         model.train()
         loss_list = []
-        for batch, data in enumerate(train_dataloader):
-            subgraph, feats, labels = data
+        for batch, subgraph in enumerate(train_dataloader):
             subgraph = subgraph.to(device)
-            feats = feats.to(device)
-            labels = labels.to(device)
             model.g = subgraph
             for layer in model.gat_layers:
                 layer.g = subgraph
-            logits = model(feats.float())
-            loss = loss_fcn(logits, labels.float())
+            logits = model(subgraph.ndata['feat'].float())
+            loss = loss_fcn(logits, subgraph.ndata['label'])
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
......
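
With features and labels riding along in `ndata`, the GAT input pipeline above reduces to batching graphs. A runnable sketch of the same loop, assuming the refactored PPIDataset:

    import dgl
    from dgl.data.ppi import PPIDataset
    from torch.utils.data import DataLoader

    def collate(graphs):
        return dgl.batch(graphs)                   # items are DGLGraphs already

    loader = DataLoader(PPIDataset(mode='valid'), batch_size=2, collate_fn=collate)
    for subgraph in loader:
        feats = subgraph.ndata['feat'].float()
        labels = subgraph.ndata['label']
        # model forward/backward would go here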
"""PPI Dataset. """ PPIDataset for inductive learning. """
(zhang hao): Used for inductive learning.
"""
import json import json
import numpy as np import numpy as np
import networkx as nx import networkx as nx
from networkx.readwrite import json_graph from networkx.readwrite import json_graph
import os
from .utils import download, extract_archive, get_download_dir, _get_dgl_url from .dgl_dataset import DGLBuiltinDataset
from ..utils import retry_method_with_fix from .utils import _get_dgl_url, save_graphs, save_info, load_info, load_graphs, deprecate_property
from .. import backend as F
from ..convert import from_networkx from ..convert import from_networkx
_url = 'dataset/ppi.zip'
class PPIDataset(DGLBuiltinDataset):
class PPIDataset(object): r""" Protein-Protein Interaction dataset for inductive node classification
"""A toy Protein-Protein Interaction network dataset.
.. deprecated:: 0.5.0
Adapted from https://github.com/williamleif/GraphSAGE/tree/master/example_data. `lables` is deprecated, it is replaced by:
>>> dataset = PPIDataset()
The dataset contains 24 graphs. The average number of nodes per graph >>> for g in dataset:
is 2372. Each node has 50 features and 121 labels. .... labels = g.ndata['label']
....
We use 20 graphs for training, 2 for validation and 2 for testing. >>>
""" `features` is deprecated, it is replaced by:
def __init__(self, mode): >>> dataset = PPIDataset()
"""Initialize the dataset. >>> for g in dataset:
.... features = g.ndata['feat']
Paramters ....
--------- >>>
A toy Protein-Protein Interaction network dataset. The dataset contains
24 graphs. The average number of nodes per graph is 2372. Each node has
50 features and 121 labels. 20 graphs for training, 2 for validation
and 2 for testing.
Reference: http://snap.stanford.edu/graphsage/
PPI dataset statistics:
Train examples: 20
Valid examples: 2
Test examples: 2
Parameters
----------
mode : str mode : str
('train', 'valid', 'test'). Must be one of ('train', 'valid', 'test').
Default: 'train'
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset.
Default: False
verbose: bool
Whether to print out progress information.
Default: True.
Attributes
----------
num_labels : int
Number of labels for each node
labels : Tensor
Node labels
features : Tensor
Node features
Examples
--------
>>> dataset = PPIDataset(mode='valid')
>>> num_labels = dataset.num_labels
>>> for g in dataset:
.... feat = g.ndata['feat']
.... label = g.ndata['label']
.... # your code here
>>>
""" """
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
assert mode in ['train', 'valid', 'test'] assert mode in ['train', 'valid', 'test']
self.mode = mode self.mode = mode
self._name = 'ppi' _url = _get_dgl_url('dataset/ppi.zip')
self._dir = get_download_dir() super(PPIDataset, self).__init__(name='ppi',
self._zip_file_path = '{}/{}.zip'.format(self._dir, self._name) url=_url,
self._load() raw_dir=raw_dir,
self._preprocess() force_reload=force_reload,
verbose=verbose)
-    def _download(self):
-        download(_get_dgl_url(_url), path=self._zip_file_path)
-        extract_archive(self._zip_file_path,
-                        '{}/{}'.format(self._dir, self._name))
-
-    @retry_method_with_fix(_download)
-    def _load(self):
-        """Loads input data.
-
-        train/test/valid_graph.json => the graph data used for training,
-        test and validation as json format;
-        train/test/valid_feats.npy => the feature vectors of nodes as
-        numpy.ndarry object, it's shape is [n, v],
-        n is the number of nodes, v is the feature's dimension;
-        train/test/valid_labels.npy=> the labels of the input nodes, it
-        is a numpy ndarry, it's like[[0, 0, 1, ... 0],
-        [0, 1, 1, 0 ...1]], shape of it is n*h, n is the number of nodes,
-        h is the label's dimension;
-        train/test/valid/_graph_id.npy => the element in it indicates which
-        graph the nodes belong to, it is a one dimensional numpy.ndarray
-        object and the length of it is equal the number of nodes,
-        it's like [1, 1, 2, 1...20].
-        """
-        print('Loading G...')
-        if self.mode == 'train':
-            with open('{}/ppi/train_graph.json'.format(self._dir)) as jsonfile:
-                g_data = json.load(jsonfile)
-            self.labels = np.load('{}/ppi/train_labels.npy'.format(self._dir))
-            self.features = np.load('{}/ppi/train_feats.npy'.format(self._dir))
-            self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
-            self.graph_id = np.load('{}/ppi/train_graph_id.npy'.format(self._dir))
-        if self.mode == 'valid':
-            with open('{}/ppi/valid_graph.json'.format(self._dir)) as jsonfile:
-                g_data = json.load(jsonfile)
-            self.labels = np.load('{}/ppi/valid_labels.npy'.format(self._dir))
-            self.features = np.load('{}/ppi/valid_feats.npy'.format(self._dir))
-            self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
-            self.graph_id = np.load('{}/ppi/valid_graph_id.npy'.format(self._dir))
-        if self.mode == 'test':
-            with open('{}/ppi/test_graph.json'.format(self._dir)) as jsonfile:
-                g_data = json.load(jsonfile)
-            self.labels = np.load('{}/ppi/test_labels.npy'.format(self._dir))
-            self.features = np.load('{}/ppi/test_feats.npy'.format(self._dir))
-            self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
-            self.graph_id = np.load('{}/ppi/test_graph_id.npy'.format(self._dir))
-
-    def _preprocess(self):
-        if self.mode == 'train':
-            self.train_mask_list = []
-            self.train_graphs = []
-            self.train_labels = []
-            for train_graph_id in range(1, 21):
-                train_graph_mask = np.where(self.graph_id == train_graph_id)[0]
-                self.train_mask_list.append(train_graph_mask)
-                self.train_graphs.append(self.graph.subgraph(train_graph_mask))
-                self.train_labels.append(self.labels[train_graph_mask])
-        if self.mode == 'valid':
-            self.valid_mask_list = []
-            self.valid_graphs = []
-            self.valid_labels = []
-            for valid_graph_id in range(21, 23):
-                valid_graph_mask = np.where(self.graph_id == valid_graph_id)[0]
-                self.valid_mask_list.append(valid_graph_mask)
-                self.valid_graphs.append(self.graph.subgraph(valid_graph_mask))
-                self.valid_labels.append(self.labels[valid_graph_mask])
-        if self.mode == 'test':
-            self.test_mask_list = []
-            self.test_graphs = []
-            self.test_labels = []
-            for test_graph_id in range(23, 25):
-                test_graph_mask = np.where(self.graph_id == test_graph_id)[0]
-                self.test_mask_list.append(test_graph_mask)
-                self.test_graphs.append(self.graph.subgraph(test_graph_mask))
-                self.test_labels.append(self.labels[test_graph_mask])
+    def process(self):
+        graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode))
+        label_file = os.path.join(self.save_path, '{}_labels.npy'.format(self.mode))
+        feat_file = os.path.join(self.save_path, '{}_feats.npy'.format(self.mode))
+        graph_id_file = os.path.join(self.save_path, '{}_graph_id.npy'.format(self.mode))
+
+        g_data = json.load(open(graph_file))
+        self._labels = np.load(label_file)
+        self._feats = np.load(feat_file)
+        self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
+        graph_id = np.load(graph_id_file)
+
+        # lo, hi means the range of graph ids for different portion of the dataset,
+        # 20 graphs for training, 2 for validation and 2 for testing.
+        lo, hi = 1, 21
+        if self.mode == 'valid':
+            lo, hi = 21, 23
+        elif self.mode == 'test':
+            lo, hi = 23, 25
+
+        graph_masks = []
+        self.graphs = []
+        for g_id in range(lo, hi):
+            g_mask = np.where(graph_id == g_id)[0]
+            graph_masks.append(g_mask)
+            g = self.graph.subgraph(g_mask)
+            g.ndata['feat'] = F.tensor(self._feats[g_mask], dtype=F.data_type_dict['float32'])
+            g.ndata['label'] = F.tensor(self._labels[g_mask], dtype=F.data_type_dict['float32'])
+            self.graphs.append(g)
+
+    def has_cache(self):
+        graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
+        g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
+        info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
+        return os.path.exists(graph_list_path) and os.path.exists(g_path) and os.path.exists(info_path)
+
+    def save(self):
+        graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
+        g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
+        info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
+        save_graphs(graph_list_path, self.graphs)
+        save_graphs(g_path, self.graph)
+        save_info(info_path, {'labels': self._labels, 'feats': self._feats})
+
+    def load(self):
+        graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
+        g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
+        info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
+        self.graphs = load_graphs(graph_list_path)[0]
+        g, _ = load_graphs(g_path)
+        self.graph = g[0]
+        info = load_info(info_path)
+        self._labels = info['labels']
+        self._feats = info['feats']
+
+    @property
+    def num_labels(self):
+        return 121
+
+    @property
+    def labels(self):
+        deprecate_property('dataset.labels', 'dataset.graphs[i].ndata[\'label\']')
+        return self._labels
+
+    @property
+    def features(self):
+        deprecate_property('dataset.features', 'dataset.graphs[i].ndata[\'feat\']')
+        return self._feats
     def __len__(self):
         """Return number of samples in this dataset."""
-        if self.mode == 'train':
-            return len(self.train_mask_list)
-        if self.mode == 'valid':
-            return len(self.valid_mask_list)
-        if self.mode == 'test':
-            return len(self.test_mask_list)
+        return len(self.graphs)

     def __getitem__(self, item):
-        """Get the i^th sample.
+        """Get the item^th sample.

-        Paramters
-        ---------
-        idx : int
+        Parameters
+        ----------
+        item : int
             The sample index.

         Returns
         -------
-        (dgl.DGLGraph, ndarray)
-            The graph, and its label.
+        dgl.DGLGraph
+            graph structure, node features and node labels.
+            - ndata['feat']: node features
+            - ndata['label']: node labels
         """
-        if self.mode == 'train':
-            g = self.train_graphs[item]
-            g.ndata['feat'] = self.features[self.train_mask_list[item]]
-            label = self.train_labels[item]
-        elif self.mode == 'valid':
-            g = self.valid_graphs[item]
-            g.ndata['feat'] = self.features[self.valid_mask_list[item]]
-            label = self.valid_labels[item]
-        elif self.mode == 'test':
-            g = self.test_graphs[item]
-            g.ndata['feat'] = self.features[self.test_mask_list[item]]
-            label = self.test_labels[item]
-        return g, label
+        return self.graphs[item]
 class LegacyPPIDataset(PPIDataset):
@@ -156,7 +180,7 @@ class LegacyPPIDataset(PPIDataset):
     """
     def __getitem__(self, item):
-        """Get the i^th sample.
+        """Get the item^th sample.

         Parameters
         ----------
@@ -165,12 +189,8 @@ class LegacyPPIDataset(PPIDataset):
         Returns
         -------
-        (dgl.DGLGraph, ndarray, ndarray)
+        (dgl.DGLGraph, Tensor, Tensor)
             The graph, features and its label.
         """
-        if self.mode == 'train':
-            return self.train_graphs[item], self.features[self.train_mask_list[item]], self.train_labels[item]
-        if self.mode == 'valid':
-            return self.valid_graphs[item], self.features[self.valid_mask_list[item]], self.valid_labels[item]
-        if self.mode == 'test':
-            return self.test_graphs[item], self.features[self.test_mask_list[item]], self.test_labels[item]
+        return self.graphs[item], self.graphs[item].ndata['feat'], self.graphs[item].ndata['label']
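
For orientation, the new class plugs into the DGLBuiltinDataset lifecycle: the base class downloads the archive, calls process(), then save(); on later constructions it calls has_cache() and, when that returns True, load() instead of reprocessing. A hedged skeleton of that protocol (ToyDataset and its cache file are hypothetical; a real subclass would also pass a download url, as PPIDataset does above):

    import os
    from dgl.data.dgl_dataset import DGLBuiltinDataset
    from dgl.data.utils import save_graphs, load_graphs

    class ToyDataset(DGLBuiltinDataset):
        def __init__(self, raw_dir=None, force_reload=False, verbose=False):
            # 'toy' and url=None are placeholders for this sketch only
            super(ToyDataset, self).__init__(name='toy', url=None,
                                             raw_dir=raw_dir,
                                             force_reload=force_reload,
                                             verbose=verbose)

        def process(self):
            self.graphs = []                      # parse raw files into graphs here

        def has_cache(self):                      # consulted before process()
            return os.path.exists(os.path.join(self.save_path, 'toy.bin'))

        def save(self):                           # called after process()
            save_graphs(os.path.join(self.save_path, 'toy.bin'), self.graphs)

        def load(self):                           # used when has_cache() is True
            self.graphs = load_graphs(os.path.join(self.save_path, 'toy.bin'))[0]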