Commit c03046a0 authored by Mufei Li, committed by Quan (Andy) Gan

[Model][Hetero] HAN (#868)

* Add HAN

* Fix

* WIP; load raw ACM dataset

* DGL's own preprocessing with metapath coalescer

* various fixes

* comparison against simple logistic regression

* rename

* fix test
parent e16667bf
# Heterogeneous Graph Attention Network (HAN) with DGL
This is an attempt to implement HAN with DGL's latest APIs for heterogeneous graphs.
The authors' implementation can be found [here](https://github.com/Jhy1993/HAN).
## Usage
- `python main.py` reproduces HAN on the authors' preprocessed ACM dataset.
- `python main.py --hetero` reproduces HAN on DGL's own dataset, built from the raw ACM data with `dgl.metapath_reachable_graph`.
## Performance
Reference performance numbers for the ACM dataset (all numbers in %):

|                                  | micro f1 score | macro f1 score |
| -------------------------------- | -------------- | -------------- |
| HAN (paper)                      | 89.22          | 89.40          |
| HAN (DGL)                        | 88.99          | 89.02          |
| Softmax regression (own dataset) | 89.66          | 89.62          |
| HAN (DGL, own dataset)           | 91.51          | 91.66          |
We ran a softmax regression as a sanity check on the difficulty of our own dataset. HAN still shows a clear improvement over this baseline.
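
For reference, a minimal sketch of such a baseline (a hedged example, not the exact script behind the table; it assumes the same `load_data` helper as `main.py` and uses scikit-learn's `LogisticRegression`, whose multinomial form is a softmax regression):

```python
# Hedged sketch: softmax-regression baseline on the raw bag-of-words features.
from sklearn.linear_model import LogisticRegression
from utils import load_data

g, features, labels, num_classes, train_idx, val_idx, test_idx, \
    train_mask, val_mask, test_mask = load_data('ACMRaw')
clf = LogisticRegression(multi_class='multinomial', solver='lbfgs', max_iter=1000)
clf.fit(features[train_idx].numpy(), labels[train_idx].numpy())
print('test accuracy:', clf.score(features[test_idx].numpy(), labels[test_idx].numpy()))
```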
# ========== main.py ==========
import torch
from sklearn.metrics import f1_score
from utils import load_data, EarlyStopping
def score(logits, labels):
_, indices = torch.max(logits, dim=1)
prediction = indices.long().cpu().numpy()
labels = labels.cpu().numpy()
accuracy = (prediction == labels).sum() / len(prediction)
micro_f1 = f1_score(labels, prediction, average='micro')
macro_f1 = f1_score(labels, prediction, average='macro')
return accuracy, micro_f1, macro_f1
def evaluate(model, g, features, labels, mask, loss_func):
model.eval()
with torch.no_grad():
logits = model(g, features)
loss = loss_func(logits[mask], labels[mask])
accuracy, micro_f1, macro_f1 = score(logits[mask], labels[mask])
return loss, accuracy, micro_f1, macro_f1
def main(args):
    # If args['hetero'] is True, g is a heterogeneous graph;
    # otherwise, it is a list of homogeneous graphs, one per metapath.
g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
val_mask, test_mask = load_data(args['dataset'])
features = features.to(args['device'])
labels = labels.to(args['device'])
train_mask = train_mask.to(args['device'])
val_mask = val_mask.to(args['device'])
test_mask = test_mask.to(args['device'])
if args['hetero']:
from model_hetero import HAN
model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp']],
in_size=features.shape[1],
hidden_size=args['hidden_units'],
out_size=num_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])
else:
from model import HAN
model = HAN(num_meta_paths=len(g),
in_size=features.shape[1],
hidden_size=args['hidden_units'],
out_size=num_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])
stopper = EarlyStopping(patience=args['patience'])
loss_fcn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
weight_decay=args['weight_decay'])
for epoch in range(args['num_epochs']):
model.train()
logits = model(g, features)
loss = loss_fcn(logits[train_mask], labels[train_mask])
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn)
early_stop = stopper.step(val_loss.data.item(), val_acc, model)
print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '
'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.item(), val_micro_f1, val_macro_f1))
if early_stop:
break
stopper.load_checkpoint(model)
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)
print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
test_loss.item(), test_micro_f1, test_macro_f1))
if __name__ == '__main__':
import argparse
from utils import setup
parser = argparse.ArgumentParser('HAN')
parser.add_argument('-s', '--seed', type=int, default=1,
help='Random seed')
parser.add_argument('-ld', '--log-dir', type=str, default='results',
help='Dir for saving training results')
parser.add_argument('--hetero', action='store_true',
help='Use metapath coalescing with DGL\'s own dataset')
args = parser.parse_args().__dict__
args = setup(args)
main(args)
# ========== model.py ==========
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
class SemanticAttention(nn.Module):
def __init__(self, in_size, hidden_size=128):
super(SemanticAttention, self).__init__()
self.project = nn.Sequential(
nn.Linear(in_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, 1, bias=False)
)
def forward(self, z):
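        # z: (N, M, D * K) node embeddings stacked over the M metapaths;
        # beta: (N, M, 1) attention over metapaths, normalized along dim=1.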
w = self.project(z)
beta = torch.softmax(w, dim=1)
return (beta * z).sum(1)
class HANLayer(nn.Module):
"""
HAN layer.
Arguments
---------
num_meta_paths : number of homogeneous graphs generated from the metapaths.
in_size : input feature dimension
out_size : output feature dimension
layer_num_heads : number of attention heads
dropout : Dropout probability
Inputs
------
g : list[DGLGraph]
List of graphs
h : tensor
Input features
Outputs
-------
tensor
The output feature
"""
def __init__(self, num_meta_paths, in_size, out_size, layer_num_heads, dropout):
super(HANLayer, self).__init__()
# One GAT layer for each meta path based adjacency matrix
self.gat_layers = nn.ModuleList()
for i in range(num_meta_paths):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.num_meta_paths = num_meta_paths
def forward(self, gs, h):
semantic_embeddings = []
for i, g in enumerate(gs):
semantic_embeddings.append(self.gat_layers[i](g, h).flatten(1))
semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)
return self.semantic_attention(semantic_embeddings) # (N, D * K)
class HAN(nn.Module):
def __init__(self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
super(HAN, self).__init__()
self.layers = nn.ModuleList()
self.layers.append(HANLayer(num_meta_paths, in_size, hidden_size, num_heads[0], dropout))
for l in range(1, len(num_heads)):
self.layers.append(HANLayer(num_meta_paths, hidden_size * num_heads[l-1],
hidden_size, num_heads[l], dropout))
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)
def forward(self, g, h):
for gnn in self.layers:
h = gnn(g, h)
return self.predict(h)
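
# Usage sketch, assuming `gs`, `features` and `num_classes` come from
# utils.load_data('ACM') and the default configuration in utils.py:
#
#   model = HAN(num_meta_paths=len(gs), in_size=features.shape[1],
#               hidden_size=8, out_size=num_classes,
#               num_heads=[8], dropout=0.6)
#   logits = model(gs, features)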
"""This model shows an example of using dgl.metapath_reachable_graph on the original heterogeneous
graph.
Because the original HAN implementation only gives the preprocessed homogeneous graph, this model
could not reproduce the result in HAN as they did not provide the preprocessing code, and we
constructed another dataset from ACM with a different set of papers, connections, features and
labels.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GATConv
class SemanticAttention(nn.Module):
def __init__(self, in_size, hidden_size=128):
super(SemanticAttention, self).__init__()
self.project = nn.Sequential(
nn.Linear(in_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, 1, bias=False)
)
def forward(self, z):
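        # Same as in model.py: softmax over the metapath dimension of
        # z with shape (N, M, D * K).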
w = self.project(z)
beta = torch.softmax(w, dim=1)
return (beta * z).sum(1)
class HANLayer(nn.Module):
"""
HAN layer.
Arguments
---------
meta_paths : list of metapaths, each as a list of edge types
in_size : input feature dimension
out_size : output feature dimension
layer_num_heads : number of attention heads
dropout : Dropout probability
Inputs
------
g : DGLHeteroGraph
The heterogeneous graph
h : tensor
Input features
Outputs
-------
tensor
The output feature
"""
def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout):
super(HANLayer, self).__init__()
# One GAT layer for each meta path based adjacency matrix
self.gat_layers = nn.ModuleList()
for i in range(len(meta_paths)):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)
self._cached_graph = None
self._cached_coalesced_graph = {}
def forward(self, g, h):
semantic_embeddings = []
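        # Recompute the metapath-induced homogeneous graphs only when a new
        # input graph is seen: dgl.metapath_reachable_graph is expensive, so
        # the result is cached per metapath.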
if self._cached_graph is None or self._cached_graph is not g:
self._cached_graph = g
self._cached_coalesced_graph.clear()
for meta_path in self.meta_paths:
self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph(
g, meta_path)
for i, meta_path in enumerate(self.meta_paths):
new_g = self._cached_coalesced_graph[meta_path]
semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)
return self.semantic_attention(semantic_embeddings) # (N, D * K)
class HAN(nn.Module):
def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
super(HAN, self).__init__()
self.layers = nn.ModuleList()
self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout))
for l in range(1, len(num_heads)):
self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l-1],
hidden_size, num_heads[l], dropout))
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)
def forward(self, g, h):
for gnn in self.layers:
h = gnn(g, h)
return self.predict(h)
# ========== utils.py ==========
import datetime
import dgl
import errno
import numpy as np
import os
import pickle
import random
import torch
from dgl.data.utils import download, get_download_dir, _get_dgl_url
from pprint import pprint
from scipy import sparse
from scipy import io as sio
def set_random_seed(seed=0):
"""Set random seed.
Parameters
----------
seed : int
Random seed to use
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def mkdir_p(path, log=True):
"""Create a directory for the specified path.
Parameters
----------
path : str
Path name
log : bool
Whether to print result for directory creation
"""
try:
os.makedirs(path)
if log:
print('Created directory {}'.format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
print('Directory {} already exists.'.format(path))
else:
raise
def get_date_postfix():
"""Get a date based postfix for directory name.
Returns
-------
post_fix : str
"""
dt = datetime.datetime.now()
post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second)
return post_fix
def setup_log_dir(args, sampling=False):
"""Name and create directory for logging.
Parameters
----------
args : dict
Configuration
Returns
-------
log_dir : str
Path for logging directory
sampling : bool
Whether we are using sampling based training
"""
date_postfix = get_date_postfix()
log_dir = os.path.join(
args['log_dir'],
'{}_{}'.format(args['dataset'], date_postfix))
if sampling:
log_dir = log_dir + '_sampling'
mkdir_p(log_dir)
return log_dir
# The configuration below is from the paper.
default_configure = {
'lr': 0.005, # Learning rate
'num_heads': [8], # Number of attention heads for node-level attention
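    # (one entry per HANLayer; e.g. [8, 8] would stack two layers)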
'hidden_units': 8,
'dropout': 0.6,
'weight_decay': 0.001,
'num_epochs': 200,
'patience': 100
}
sampling_configure = {
'batch_size': 20
}
def setup(args):
args.update(default_configure)
set_random_seed(args['seed'])
args['dataset'] = 'ACMRaw' if args['hetero'] else 'ACM'
    args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args['log_dir'] = setup_log_dir(args)
return args
def setup_for_sampling(args):
args.update(default_configure)
args.update(sampling_configure)
set_random_seed()
    args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args['log_dir'] = setup_log_dir(args, sampling=True)
return args
def get_binary_mask(total_size, indices):
mask = torch.zeros(total_size)
mask[indices] = 1
return mask.byte()
def load_acm(remove_self_loop):
url = 'dataset/ACM3025.pkl'
data_path = get_download_dir() + '/ACM3025.pkl'
download(_get_dgl_url(url), path=data_path)
with open(data_path, 'rb') as f:
data = pickle.load(f)
labels, features = torch.from_numpy(data['label'].todense()).long(), \
torch.from_numpy(data['feature'].todense()).float()
num_classes = labels.shape[1]
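    # Convert the one-hot label matrix to integer class indices.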
labels = labels.nonzero()[:, 1]
if remove_self_loop:
num_nodes = data['label'].shape[0]
data['PAP'] = sparse.csr_matrix(data['PAP'] - np.eye(num_nodes))
data['PLP'] = sparse.csr_matrix(data['PLP'] - np.eye(num_nodes))
# Adjacency matrices for meta path based neighbors
# (Mufei): I verified both of them are binary adjacency matrices with self loops
author_g = dgl.graph(data['PAP'], ntype='paper', etype='author')
subject_g = dgl.graph(data['PLP'], ntype='paper', etype='subject')
gs = [author_g, subject_g]
train_idx = torch.from_numpy(data['train_idx']).long().squeeze(0)
val_idx = torch.from_numpy(data['val_idx']).long().squeeze(0)
test_idx = torch.from_numpy(data['test_idx']).long().squeeze(0)
num_nodes = author_g.number_of_nodes()
train_mask = get_binary_mask(num_nodes, train_idx)
val_mask = get_binary_mask(num_nodes, val_idx)
test_mask = get_binary_mask(num_nodes, test_idx)
print('dataset loaded')
pprint({
'dataset': 'ACM',
'train': train_mask.sum().item() / num_nodes,
'val': val_mask.sum().item() / num_nodes,
'test': test_mask.sum().item() / num_nodes
})
return gs, features, labels, num_classes, train_idx, val_idx, test_idx, \
train_mask, val_mask, test_mask
def load_acm_raw(remove_self_loop):
assert not remove_self_loop
url = 'dataset/ACM.mat'
data_path = get_download_dir() + '/ACM.mat'
download(_get_dgl_url(url), path=data_path)
data = sio.loadmat(data_path)
p_vs_l = data['PvsL'] # paper-field?
p_vs_a = data['PvsA'] # paper-author
p_vs_t = data['PvsT'] # paper-term, bag of words
p_vs_c = data['PvsC'] # paper-conference, labels come from that
# We assign
# (1) KDD papers as class 0 (data mining),
# (2) SIGMOD and VLDB papers as class 1 (database),
# (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
conf_ids = [0, 1, 9, 10, 13]
label_ids = [0, 1, 2, 2, 1]
p_vs_c_filter = p_vs_c[:, conf_ids]
p_selected = (p_vs_c_filter.sum(1) != 0).A1.nonzero()[0]
p_vs_l = p_vs_l[p_selected]
p_vs_a = p_vs_a[p_selected]
p_vs_t = p_vs_t[p_selected]
p_vs_c = p_vs_c[p_selected]
pa = dgl.bipartite(p_vs_a, 'paper', 'pa', 'author')
ap = dgl.bipartite(p_vs_a.transpose(), 'author', 'ap', 'paper')
pl = dgl.bipartite(p_vs_l, 'paper', 'pf', 'field')
lp = dgl.bipartite(p_vs_l.transpose(), 'field', 'fp', 'paper')
hg = dgl.hetero_from_relations([pa, ap, pl, lp])
features = torch.FloatTensor(p_vs_t.toarray())
pc_p, pc_c = p_vs_c.nonzero()
labels = np.zeros(len(p_selected), dtype=np.int64)
for conf_id, label_id in zip(conf_ids, label_ids):
labels[pc_p[pc_c == conf_id]] = label_id
labels = torch.LongTensor(labels)
num_classes = 3
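    # Stratified split: give each paper a random position in [0, 1] within its
    # conference, then threshold at 0.2/0.3 for a ~20/10/70 train/val/test split.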
float_mask = np.zeros(len(pc_p))
for conf_id in conf_ids:
pc_c_mask = (pc_c == conf_id)
float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
train_idx = np.where(float_mask <= 0.2)[0]
val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
test_idx = np.where(float_mask > 0.3)[0]
num_nodes = hg.number_of_nodes('paper')
train_mask = get_binary_mask(num_nodes, train_idx)
val_mask = get_binary_mask(num_nodes, val_idx)
test_mask = get_binary_mask(num_nodes, test_idx)
return hg, features, labels, num_classes, train_idx, val_idx, test_idx, \
train_mask, val_mask, test_mask
def load_data(dataset, remove_self_loop=False):
if dataset == 'ACM':
return load_acm(remove_self_loop)
elif dataset == 'ACMRaw':
return load_acm_raw(remove_self_loop)
else:
        raise NotImplementedError('Unsupported dataset {}'.format(dataset))
class EarlyStopping(object):
def __init__(self, patience=10):
dt = datetime.datetime.now()
self.filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
dt.date(), dt.hour, dt.minute, dt.second)
self.patience = patience
self.counter = 0
self.best_acc = None
self.best_loss = None
self.early_stop = False
def step(self, loss, acc, model):
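        # An epoch counts toward patience only when both the loss and the
        # accuracy are worse than the best seen; otherwise the counter resets,
        # and a checkpoint is saved whenever neither metric has regressed.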
if self.best_loss is None:
self.best_acc = acc
self.best_loss = loss
self.save_checkpoint(model)
elif (loss > self.best_loss) and (acc < self.best_acc):
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
if (loss <= self.best_loss) and (acc >= self.best_acc):
self.save_checkpoint(model)
self.best_loss = np.min((loss, self.best_loss))
self.best_acc = np.max((acc, self.best_acc))
self.counter = 0
return self.early_stop
def save_checkpoint(self, model):
"""Saves model when validation loss decreases."""
torch.save(model.state_dict(), self.filename)
def load_checkpoint(self, model):
"""Load the latest checkpoint."""
model.load_state_dict(torch.load(self.filename))
@@ -6,7 +6,7 @@ from contextlib import contextmanager
import networkx as nx
import dgl
-from .base import ALL, is_all, DGLError
+from .base import ALL, is_all, DGLError, dgl_warning
from . import backend as F
from . import init
from .frame import FrameRef, Frame, Scheme, sync_frame_initializer
@@ -3034,7 +3034,7 @@ class DGLGraph(DGLBaseGraph):
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes)
return subgraph.DGLSubGraph(self, sgi)
-    def adjacency_matrix_scipy(self, transpose=False, fmt='csr', return_edge_ids=None):
+    def adjacency_matrix_scipy(self, transpose=None, fmt='csr', return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
@@ -3060,6 +3060,12 @@ class DGLGraph(DGLBaseGraph):
The scipy representation of adjacency matrix.
"""
if transpose is None:
dgl_warning(
"Currently adjacency_matrix() returns a matrix with destination as rows"
" by default. In 0.5 the result will have source as rows"
" (i.e. transpose=True)")
transpose = False
return self._graph.adjacency_matrix_scipy(transpose, fmt, return_edge_ids)
-    def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
+    def adjacency_matrix(self, transpose=None, ctx=F.cpu()):
@@ -3083,6 +3089,12 @@ class DGLGraph(DGLBaseGraph):
SparseTensor
The adjacency matrix.
"""
if transpose is None:
dgl_warning(
"Currently adjacency_matrix() returns a matrix with destination as rows"
" by default. In 0.5 the result will have source as rows"
" (i.e. transpose=True)")
transpose = False
return self._graph.adjacency_matrix(transpose, ctx)[0]
def incidence_matrix(self, typestr, ctx=F.cpu()):
@@ -12,7 +12,7 @@ from . import init
from .runtime import ir, scheduler, Runtime, GraphAdapter
from .frame import Frame, FrameRef, frame_like, sync_frame_initializer
from .view import HeteroNodeView, HeteroNodeDataView, HeteroEdgeView, HeteroEdgeDataView
-from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError
+from .base import ALL, SLICE_FULL, NTYPE, NID, ETYPE, EID, is_all, DGLError, dgl_warning
__all__ = ['DGLHeteroGraph', 'combine_names']
@@ -1403,6 +1403,13 @@ class DGLHeteroGraph(object):
SparseTensor or scipy.sparse.spmatrix
Adjacency matrix.
"""
if transpose is None:
dgl_warning(
"Currently adjacency_matrix() returns a matrix with destination as rows"
" by default. In 0.5 the result will have source as rows"
" (i.e. transpose=True)")
transpose = False
etid = self.get_etype_id(etype)
if scipy_fmt is None:
return self._graph.adjacency_matrix(etid, transpose, ctx)[0]
@@ -7,11 +7,12 @@ from .graph import DGLGraph
from . import backend as F
from .graph_index import from_coo
from .batched_graph import BatchedDGLGraph, unbatch
from .convert import graph, bipartite
__all__ = ['line_graph', 'khop_adj', 'khop_graph', 'reverse', 'to_simple_graph', 'to_bidirected',
'laplacian_lambda_max', 'knn_graph', 'segmented_knn_graph', 'add_self_loop',
-           'remove_self_loop']
+           'remove_self_loop', 'metapath_reachable_graph']
def pairwise_squared_distance(x):
@@ -404,6 +405,55 @@ def laplacian_lambda_max(g):
return_eigenvectors=False)[0].real)
return rst
def metapath_reachable_graph(g, metapath):
"""Return a graph where the successors of any node ``u`` are nodes reachable from ``u`` by
the given metapath.
If the beginning node type ``s`` and ending node type ``t`` are the same, it will return
a homogeneous graph with node type ``s = t``. Otherwise, a unidirectional bipartite graph
with source node type ``s`` and destination node type ``t`` is returned.
    In both cases, two nodes ``u`` and ``v`` are connected by an edge ``(u, v)`` if there
    exists at least one path from ``u`` to ``v`` matching the metapath.

    The resulting graph keeps the node sets of types ``s`` and ``t`` from the original
    graph, even if some nodes end up with no neighbors.

    The features of the source/destination node types in the original graph are copied
    to the new graph.
Parameters
----------
g : DGLHeteroGraph
The input graph
metapath : list[str or tuple of str]
Metapath in the form of a list of edge types
Returns
-------
DGLHeteroGraph
A homogeneous or bipartite graph.
"""
adj = 1
for etype in metapath:
adj = adj * g.adj(etype=etype, scipy_fmt='csr', transpose=True)
adj = (adj != 0).tocsr()
srctype = g.to_canonical_etype(metapath[0])[0]
dsttype = g.to_canonical_etype(metapath[-1])[2]
if srctype == dsttype:
assert adj.shape[0] == adj.shape[1]
new_g = graph(adj, ntype=srctype)
else:
new_g = bipartite(adj, utype=srctype, vtype=dsttype)
for key, value in g.nodes[srctype].data.items():
new_g.nodes[srctype].data[key] = value
if srctype != dsttype:
for key, value in g.nodes[dsttype].data.items():
new_g.nodes[dsttype].data[key] = value
return new_g
def add_self_loop(g):
"""Return a new graph containing all the edges in the input graph plus self loops
@@ -470,5 +520,4 @@ def remove_self_loop(g):
new_g.add_edges(src[non_self_edges_idx], dst[non_self_edges_idx])
return new_g
_init_api("dgl.transform")
@@ -657,6 +657,23 @@ def test_convert():
g = dgl.to_homo(hg)
assert g.number_of_nodes() == 5
def test_transform():
g = create_test_heterograph()
x = F.randn((3, 5))
g.nodes['user'].data['h'] = x
new_g = dgl.metapath_reachable_graph(g, ['follows', 'plays'])
assert new_g.ntypes == ['user', 'game']
assert new_g.number_of_edges() == 3
assert F.asnumpy(new_g.has_edges_between([0, 0, 1], [0, 1, 1])).all()
new_g = dgl.metapath_reachable_graph(g, ['follows'])
assert new_g.ntypes == ['user']
assert new_g.number_of_edges() == 2
assert F.asnumpy(new_g.has_edges_between([0, 1], [1, 2])).all()
def test_subgraph():
g = create_test_heterograph()
x = F.randn((3, 5))
@@ -1142,6 +1159,7 @@ if __name__ == '__main__':
test_view1()
test_flatten()
test_convert()
test_transform()
test_subgraph()
test_apply()
test_level1()