Unverified Commit dce89919 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
from ruamel.yaml.comments import CommentedMap
@@ -14,8 +13,10 @@ def deep_convert_dict(layer):
return to_ret
import collections.abc
def merge_comment(d, comment_dict, column=30):
for k, v in comment_dict.items():
if isinstance(v, collections.abc.Mapping):
......
#!/usr/bin/env python
from setuptools import find_packages
from distutils.core import setup
setup(name='dglgo',
version='0.0.2',
description='DGL',
author='DGL Team',
author_email='wmjlyjemaine@gmail.com',
from setuptools import find_packages
setup(
name="dglgo",
version="0.0.2",
description="DGL",
author="DGL Team",
author_email="wmjlyjemaine@gmail.com",
packages=find_packages(),
install_requires=[
'typer>=0.4.0',
'isort>=5.10.1',
'autopep8>=1.6.0',
'numpydoc>=1.1.0',
"typer>=0.4.0",
"isort>=5.10.1",
"autopep8>=1.6.0",
"numpydoc>=1.1.0",
"pydantic>=1.9.0",
"ruamel.yaml>=0.17.20",
"PyYAML>=5.1",
"ogb>=1.3.3",
"rdkit-pypi",
"scikit-learn>=0.20.0"
"scikit-learn>=0.20.0",
],
package_data={"": ["./*"]},
include_package_data=True,
license='APACHE',
entry_points={
'console_scripts': [
"dgl = dglgo.cli.cli:main"
]
},
url='https://github.com/dmlc/dgl',
)
license="APACHE",
entry_points={"console_scripts": ["dgl = dglgo.cli.cli:main"]},
url="https://github.com/dmlc/dgl",
)
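The reformatted setup() keeps the same console_scripts entry point, which wires the installed `dgl` command to `dglgo.cli.cli:main`. A minimal sketch of what that mapping is equivalent to, assuming dglgo has been installed (e.g. via `pip install -e .`):

# Hypothetical driver script: the `dgl` console command generated by the
# entry point above is equivalent to calling dglgo.cli.cli:main directly.
from dglgo.cli.cli import main

if __name__ == "__main__":
    main()  # same as running `dgl` from the shell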
@@ -14,16 +14,18 @@
#
import os
import sys
sys.path.insert(0, os.path.abspath('../../python'))
sys.path.insert(0, os.path.abspath("../../python"))
# -- Project information -----------------------------------------------------
project = 'DGL'
copyright = '2018, DGL Team'
author = 'DGL Team'
project = "DGL"
copyright = "2018, DGL Team"
author = "DGL Team"
import dgl
version = dgl.__version__
release = dgl.__version__
dglbackend = os.environ.get("DGLBACKEND", "pytorch")
@@ -39,35 +41,35 @@ dglbackend = os.environ.get("DGLBACKEND", "pytorch")
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.intersphinx',
'sphinx.ext.graphviz',
'sphinxemoji.sphinxemoji',
'sphinx_gallery.gen_gallery',
'sphinx_copybutton',
'nbsphinx',
'nbsphinx_link',
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.coverage",
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.intersphinx",
"sphinx.ext.graphviz",
"sphinxemoji.sphinxemoji",
"sphinx_gallery.gen_gallery",
"sphinx_copybutton",
"nbsphinx",
"nbsphinx_link",
]
# Do not run notebooks on non-pytorch backends
if dglbackend != "pytorch":
nbsphinx_execute = 'never'
nbsphinx_execute = "never"
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffixes as a list of strings:
#
source_suffix = ['.rst', '.md']
source_suffix = [".rst", ".md"]
# The master toctree document.
master_doc = 'index'
master_doc = "index"
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -90,7 +92,7 @@ pygments_style = None
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
@@ -101,8 +103,8 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_css_files = ['css/custom.css']
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
@@ -118,7 +120,7 @@ html_css_files = ['css/custom.css']
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'dgldoc'
htmlhelp_basename = "dgldoc"
# -- Options for LaTeX output ------------------------------------------------
@@ -127,15 +129,12 @@ latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
@@ -145,8 +144,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'dgl.tex', 'DGL Documentation',
'DGL Team', 'manual'),
(master_doc, "dgl.tex", "DGL Documentation", "DGL Team", "manual"),
]
@@ -154,10 +152,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'dgl', 'DGL Documentation',
[author], 1)
]
man_pages = [(master_doc, "dgl", "DGL Documentation", [author], 1)]
# -- Options for Texinfo output ----------------------------------------------
@@ -166,9 +161,15 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'dgl', 'DGL Documentation',
author, 'dgl', 'Library for deep learning on graphs.',
'Miscellaneous'),
(
master_doc,
"dgl",
"DGL Documentation",
author,
"dgl",
"Library for deep learning on graphs.",
"Miscellaneous",
),
]
@@ -187,64 +188,71 @@ epub_title = project
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
epub_exclude_files = ["search.html"]
# -- Extension configuration -------------------------------------------------
autosummary_generate = True
autodoc_member_order = 'alphabetical'
autodoc_member_order = "alphabetical"
intersphinx_mapping = {
'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
'matplotlib': ('http://matplotlib.org/', None),
'networkx' : ('https://networkx.github.io/documentation/stable', None),
"python": (
"https://docs.python.org/{.major}".format(sys.version_info),
None,
),
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
"scipy": ("http://docs.scipy.org/doc/scipy/reference", None),
"matplotlib": ("http://matplotlib.org/", None),
"networkx": ("https://networkx.github.io/documentation/stable", None),
}
# sphinx gallery configurations
from sphinx_gallery.sorting import FileNameSortKey
examples_dirs = ['../../tutorials/blitz',
'../../tutorials/large',
'../../tutorials/dist',
'../../tutorials/models',
'../../tutorials/multi',
'../../tutorials/cpu'] # path to find sources
gallery_dirs = ['tutorials/blitz/',
'tutorials/large/',
'tutorials/dist/',
'tutorials/models/',
'tutorials/multi/',
'tutorials/cpu'] # path to generate docs
examples_dirs = [
"../../tutorials/blitz",
"../../tutorials/large",
"../../tutorials/dist",
"../../tutorials/models",
"../../tutorials/multi",
"../../tutorials/cpu",
] # path to find sources
gallery_dirs = [
"tutorials/blitz/",
"tutorials/large/",
"tutorials/dist/",
"tutorials/models/",
"tutorials/multi/",
"tutorials/cpu",
] # path to generate docs
if dglbackend != "pytorch":
examples_dirs = []
gallery_dirs = []
reference_url = {
'dgl' : None,
'numpy': 'http://docs.scipy.org/doc/numpy/',
'scipy': 'http://docs.scipy.org/doc/scipy/reference',
'matplotlib': 'http://matplotlib.org/',
'networkx' : 'https://networkx.github.io/documentation/stable',
"dgl": None,
"numpy": "http://docs.scipy.org/doc/numpy/",
"scipy": "http://docs.scipy.org/doc/scipy/reference",
"matplotlib": "http://matplotlib.org/",
"networkx": "https://networkx.github.io/documentation/stable",
}
sphinx_gallery_conf = {
'backreferences_dir' : 'generated/backreferences',
'doc_module' : ('dgl', 'numpy'),
'examples_dirs' : examples_dirs,
'gallery_dirs' : gallery_dirs,
'within_subsection_order' : FileNameSortKey,
'filename_pattern' : '.py',
'download_all_examples' : False,
"backreferences_dir": "generated/backreferences",
"doc_module": ("dgl", "numpy"),
"examples_dirs": examples_dirs,
"gallery_dirs": gallery_dirs,
"within_subsection_order": FileNameSortKey,
"filename_pattern": ".py",
"download_all_examples": False,
}
# Compatibility with different backends when building tutorials
if dglbackend == 'mxnet':
sphinx_gallery_conf['filename_pattern'] = "/*(?<=mx)\.py"
if dglbackend == 'pytorch':
sphinx_gallery_conf['filename_pattern'] = "/*(?<!mx)\.py"
if dglbackend == "mxnet":
sphinx_gallery_conf["filename_pattern"] = "/*(?<=mx)\.py"
if dglbackend == "pytorch":
sphinx_gallery_conf["filename_pattern"] = "/*(?<!mx)\.py"
# sphinx-copybutton tool
copybutton_prompt_text = r'>>> |\.\.\. '
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
from pytablewriter import RstGridTableWriter, MarkdownTableWriter
import numpy as np
import pandas as pd
from dgl import DGLGraph
from dgl.data.gnn_benchmark import AmazonCoBuy, CoraFull, Coauthor
from dgl.data.karate import KarateClub
from dgl.data.gindt import GINDataset
# from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset
from dgl.data.bitcoinotc import BitcoinOTC
from dgl.data.gdelt import GDELT
from dgl.data.gindt import GINDataset
from dgl.data.gnn_benchmark import AmazonCoBuy, Coauthor, CoraFull
from dgl.data.icews18 import ICEWS18
from dgl.data.karate import KarateClub
from dgl.data.qm7b import QM7b
# from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset
from pytablewriter import MarkdownTableWriter, RstGridTableWriter
ds_list = {
"BitcoinOTC": "BitcoinOTC()",
@@ -40,9 +41,9 @@ writer = RstGridTableWriter()
# writer = MarkdownTableWriter()
extract_graph = lambda g: g if isinstance(g, DGLGraph) else g[0]
stat_list=[]
for k,v in ds_list.items():
print(k, ' ', v)
stat_list = []
for k, v in ds_list.items():
print(k, " ", v)
ds = eval(v.split("/")[0])
num_nodes = []
num_edges = []
@@ -58,10 +59,10 @@ for k,v in ds_list.items():
"# of graphs": len(ds),
"Avg. # of nodes": np.mean(num_nodes),
"Avg. # of edges": np.mean(num_edges),
"Node field": ', '.join(list(gg.ndata.keys())),
"Edge field": ', '.join(list(gg.edata.keys())),
"Node field": ", ".join(list(gg.ndata.keys())),
"Edge field": ", ".join(list(gg.edata.keys())),
# "Graph field": ', '.join(ds[0][0].gdata.keys()) if hasattr(ds[0][0], "gdata") else "",
"Temporal": hasattr(ds, "is_temporal")
"Temporal": hasattr(ds, "is_temporal"),
}
stat_list.append(dd)
......
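For the dataset-statistics script above, the elided tail presumably feeds stat_list into the RstGridTableWriter created earlier; a minimal sketch of that step, assuming the rows are first collected into a pandas DataFrame:

# Sketch: render the collected per-dataset rows as an RST grid table.
import pandas as pd
from pytablewriter import RstGridTableWriter

writer = RstGridTableWriter()
writer.from_dataframe(pd.DataFrame(stat_list))
writer.write_table()  # prints the statistics table to stdout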
@@ -26,15 +26,14 @@ def get_sddmm_kernels_gpu(idtypes, dtypes):
return ret
if __name__ == '__main__':
binary_path = 'libfeatgraph_kernels.so'
if __name__ == "__main__":
binary_path = "libfeatgraph_kernels.so"
kernels = []
idtypes = ['int32', 'int64']
dtypes = ['float16', 'float64', 'float32', 'int32', 'int64']
idtypes = ["int32", "int64"]
dtypes = ["float16", "float64", "float32", "int32", "int64"]
kernels += get_sddmm_kernels_gpu(idtypes, dtypes)
# build kernels and export the module to libfeatgraph_kernels.so
module = tvm.build(kernels, target='cuda', target_host='llvm')
module = tvm.build(kernels, target="cuda", target_host="llvm")
module.export_library(binary_path)
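Once exported, the shared library can be loaded back at runtime without rebuilding; a minimal sketch, assuming TVM is installed and using the SDDMMTreeReduction_{idtype}_{dtype} naming convention from sddmm.py below:

import tvm

# Load the prebuilt kernels and look one up by the name passed to
# tvm.lower(); indexing a Module returns a tvm.runtime.PackedFunc.
module = tvm.runtime.load_module("libfeatgraph_kernels.so")
sddmm_i32_f32 = module["SDDMMTreeReduction_int32_float32"]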
@@ -4,7 +4,7 @@ from tvm import te
def sddmm_tree_reduction_gpu(idx_type, feat_type):
""" SDDMM kernels on GPU optimized with Tree Reduction.
"""SDDMM kernels on GPU optimized with Tree Reduction.
Parameters
----------
@@ -19,35 +19,40 @@ def sddmm_tree_reduction_gpu(idx_type, feat_type):
The result IRModule.
"""
# define vars and placeholders
nnz = te.var('nnz', idx_type)
num_rows = te.var('num_rows', idx_type)
num_cols = te.var('num_cols', idx_type)
H = te.var('num_heads', idx_type)
D = te.var('feat_len', idx_type)
row = te.placeholder((nnz,), idx_type, 'row')
col = te.placeholder((nnz,), idx_type, 'col')
ufeat = te.placeholder((num_rows, H, D), feat_type, 'ufeat')
vfeat = te.placeholder((num_cols, H, D), feat_type, 'vfeat')
nnz = te.var("nnz", idx_type)
num_rows = te.var("num_rows", idx_type)
num_cols = te.var("num_cols", idx_type)
H = te.var("num_heads", idx_type)
D = te.var("feat_len", idx_type)
row = te.placeholder((nnz,), idx_type, "row")
col = te.placeholder((nnz,), idx_type, "col")
ufeat = te.placeholder((num_rows, H, D), feat_type, "ufeat")
vfeat = te.placeholder((num_cols, H, D), feat_type, "vfeat")
# define edge computation function
def edge_func(eid, h, i):
k = te.reduce_axis((0, D), name='k')
k = te.reduce_axis((0, D), name="k")
return te.sum(ufeat[row[eid], h, k] * vfeat[col[eid], h, k], axis=k)
out = te.compute((nnz, H, tvm.tir.IntImm(idx_type, 1)), edge_func, name='out')
out = te.compute(
(nnz, H, tvm.tir.IntImm(idx_type, 1)), edge_func, name="out"
)
# define schedules
sched = te.create_schedule(out.op)
edge_axis, head_axis, _ = out.op.axis
reduce_axis = out.op.reduce_axis[0]
_, red_inner = sched[out].split(reduce_axis, factor=32)
edge_outer, edge_inner = sched[out].split(edge_axis, factor=32)
sched[out].bind(red_inner, te.thread_axis('threadIdx.x'))
sched[out].bind(edge_inner, te.thread_axis('threadIdx.y'))
sched[out].bind(edge_outer, te.thread_axis('blockIdx.x'))
sched[out].bind(head_axis, te.thread_axis('blockIdx.y'))
return tvm.lower(sched, [row, col, ufeat, vfeat, out],
name='SDDMMTreeReduction_{}_{}'.format(idx_type, feat_type))
sched[out].bind(red_inner, te.thread_axis("threadIdx.x"))
sched[out].bind(edge_inner, te.thread_axis("threadIdx.y"))
sched[out].bind(edge_outer, te.thread_axis("blockIdx.x"))
sched[out].bind(head_axis, te.thread_axis("blockIdx.y"))
return tvm.lower(
sched,
[row, col, ufeat, vfeat, out],
name="SDDMMTreeReduction_{}_{}".format(idx_type, feat_type),
)
if __name__ == '__main__':
kernel0 = sddmm_tree_reduction_gpu('int32', 'float32')
if __name__ == "__main__":
kernel0 = sddmm_tree_reduction_gpu("int32", "float32")
print(kernel0)
import torch
import dgl
import dgl.backend as F
import torch
g = dgl.rand_graph(10, 15).int().to(torch.device(0))
gidx = g._graph
u = torch.rand((10,2,8), device=torch.device(0))
v = torch.rand((10,2,8), device=torch.device(0))
e = dgl.ops.gsddmm(g, 'dot', u, v)
u = torch.rand((10, 2, 8), device=torch.device(0))
v = torch.rand((10, 2, 8), device=torch.device(0))
e = dgl.ops.gsddmm(g, "dot", u, v)
print(e)
e = torch.zeros((15,2,1), device=torch.device(0))
e = torch.zeros((15, 2, 1), device=torch.device(0))
u = F.zerocopy_to_dgl_ndarray(u)
v = F.zerocopy_to_dgl_ndarray(v)
e = F.zerocopy_to_dgl_ndarray_for_write(e)
......
@@ -22,13 +22,13 @@ networks with PyTorch.
"""
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
######################################################################
# Overview of Node Classification with GNN
......
@@ -31,11 +31,11 @@ By the end of this tutorial you will be able to:
#
import os
os.environ['DGLBACKEND'] = 'pytorch'
import numpy as np
import torch
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)
# Equivalently, PyTorch LongTensors also work.
......
@@ -19,13 +19,13 @@ GNN for node classification <1_introduction>`.
"""
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
######################################################################
# Message passing and GNNs
......
@@ -19,17 +19,17 @@ By the end of this tutorial you will be able to
import itertools
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.data
######################################################################
# Overview of Link Prediction with GNN
# ------------------------------------
......
@@ -14,13 +14,13 @@ By the end of this tutorial, you will be able to
"""
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
######################################################################
# Overview of Graph Classification with GNN
@@ -54,6 +54,8 @@ print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)
from dgl.dataloading import GraphDataLoader
######################################################################
# Defining Data Loader
# --------------------
@@ -74,8 +76,6 @@ print("Number of graph categories:", dataset.gclasses)
from torch.utils.data.sampler import SubsetRandomSampler
from dgl.dataloading import GraphDataLoader
num_examples = len(dataset)
num_train = int(num_examples * 0.8)
......
@@ -88,10 +88,10 @@ interactions.head()
#
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
from dgl.data import DGLDataset
......
@@ -26,10 +26,11 @@ Sampling for GNN Training <L0_neighbor_sampling_overview>`.
#
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset("ogbn-arxiv")
@@ -284,13 +285,14 @@ valid_dataloader = dgl.dataloading.DataLoader(
)
import sklearn.metrics
######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#
import tqdm
import sklearn.metrics
best_accuracy = 0
best_model_path = "model.pt"
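The loop itself is collapsed in this diff; a rough sketch of what it does, following the tutorial's own conventions (model, opt, and the train/valid dataloaders are assumed to be defined as in the surrounding tutorial):

# Sketch of the elided loop: train on sampled MFGs, then validate and
# checkpoint whenever validation accuracy improves.
for epoch in range(10):
    model.train()
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
            inputs = mfgs[0].srcdata["feat"]
            labels = mfgs[-1].dstdata["label"]
            loss = torch.nn.functional.cross_entropy(model(mfgs, inputs), labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for input_nodes, output_nodes, mfgs in valid_dataloader:
            inputs = mfgs[0].srcdata["feat"]
            gts.append(mfgs[-1].dstdata["label"].cpu().numpy())
            preds.append(model(mfgs, inputs).argmax(1).cpu().numpy())
    accuracy = sklearn.metrics.accuracy_score(
        np.concatenate(gts), np.concatenate(preds)
    )
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), best_model_path)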
......
@@ -53,10 +53,11 @@ Sampling for Node Classification <L1_large_node_classification>`.
#
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset("ogbn-arxiv")
@@ -339,6 +340,8 @@ predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))
import sklearn.metrics
######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the
@@ -346,7 +349,6 @@ opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))
#
import tqdm
import sklearn.metrics
best_accuracy = 0
best_model_path = "model.pt"
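The tutorial's own metric computation is collapsed in this diff; for link prediction, "best" is typically measured by ROC-AUC over the scores of positive and negatively sampled edges. A minimal sketch of such a helper (names are hypothetical):

import numpy as np
import sklearn.metrics
import torch

def compute_auc(pos_score, neg_score):
    # Label real edges 1 and negatively sampled edges 0, then score the
    # concatenated predictions with ROC-AUC.
    scores = torch.cat([pos_score, neg_score]).detach().cpu().numpy()
    labels = np.concatenate(
        [np.ones(len(pos_score)), np.zeros(len(neg_score))]
    )
    return sklearn.metrics.roc_auc_score(labels, scores)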
......
@@ -14,30 +14,33 @@ for stochastic GNN training. It assumes that
"""
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset('ogbn-arxiv')
device = 'cpu' # change to 'cuda' for GPU
dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu" # change to 'cuda' for GPU
graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata['label'] = node_labels[:, 0]
graph.ndata["label"] = node_labels[:, 0]
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
node_features = graph.ndata['feat']
train_nids = idx_split["train"]
node_features = graph.ndata["feat"]
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
graph, train_nids, sampler,
graph,
train_nids,
sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0
num_workers=0,
)
input_nodes, output_nodes, mfgs = next(iter(train_dataloader))
@@ -75,8 +78,8 @@ print(mfg.num_src_nodes(), mfg.num_dst_nodes())
# will do with ``ndata`` on the graphs you have seen earlier:
#
mfg.srcdata['x'] = torch.zeros(mfg.num_src_nodes(), mfg.num_dst_nodes())
dst_feat = mfg.dstdata['feat']
mfg.srcdata["x"] = torch.zeros(mfg.num_src_nodes(), mfg.num_dst_nodes())
dst_feat = mfg.dstdata["feat"]
######################################################################
@@ -105,7 +108,11 @@ mfg.srcdata[dgl.NID], mfg.dstdata[dgl.NID]
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#
print(torch.equal(mfg.srcdata[dgl.NID][:mfg.num_dst_nodes()], mfg.dstdata[dgl.NID]))
print(
torch.equal(
mfg.srcdata[dgl.NID][: mfg.num_dst_nodes()], mfg.dstdata[dgl.NID]
)
)
######################################################################
@@ -113,7 +120,7 @@ print(torch.equal(mfg.srcdata[dgl.NID][:mfg.num_dst_nodes()], mfg.dstdata[dgl.NI
# :math:`h_u^{(l-1)}`:
#
mfg.srcdata['h'] = torch.randn(mfg.num_src_nodes(), 10)
mfg.srcdata["h"] = torch.randn(mfg.num_src_nodes(), 10)
######################################################################
@@ -132,8 +139,8 @@ mfg.srcdata['h'] = torch.randn(mfg.num_src_nodes(), 10)
import dgl.function as fn
mfg.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h'))
m_v = mfg.dstdata['h']
mfg.update_all(message_func=fn.copy_u("h", "m"), reduce_func=fn.mean("m", "h"))
m_v = mfg.dstdata["h"]
m_v
@@ -147,6 +154,7 @@ import torch.nn as nn
import torch.nn.functional as F
import tqdm
class SAGEConv(nn.Module):
"""Graph convolution module used by the GraphSAGE model.
@@ -157,6 +165,7 @@ class SAGEConv(nn.Module):
out_feat : int
Output feature size.
"""
def __init__(self, in_feat, out_feat):
super(SAGEConv, self).__init__()
# A linear submodule for projecting the input and neighbor feature to the output.
@@ -174,14 +183,15 @@ class SAGEConv(nn.Module):
"""
with g.local_scope():
h_src, h_dst = h
g.srcdata['h'] = h_src # <---
g.dstdata['h'] = h_dst # <---
g.srcdata["h"] = h_src # <---
g.dstdata["h"] = h_dst # <---
# update_all is a message passing API.
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_N'))
h_N = g.dstdata['h_N']
g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
h_N = g.dstdata["h_N"]
h_total = torch.cat([h_dst, h_N], dim=1) # <---
return self.linear(h_total)
class Model(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(Model, self).__init__()
@@ -189,28 +199,31 @@ class Model(nn.Module):
self.conv2 = SAGEConv(h_feats, num_classes)
def forward(self, mfgs, x):
h_dst = x[:mfgs[0].num_dst_nodes()]
h_dst = x[: mfgs[0].num_dst_nodes()]
h = self.conv1(mfgs[0], (x, h_dst))
h = F.relu(h)
h_dst = h[:mfgs[1].num_dst_nodes()]
h_dst = h[: mfgs[1].num_dst_nodes()]
h = self.conv2(mfgs[1], (h, h_dst))
return h
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
graph, train_nids, sampler,
graph,
train_nids,
sampler,
device=device,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0
num_workers=0,
)
model = Model(graph.ndata['feat'].shape[1], 128, dataset.num_classes).to(device)
model = Model(graph.ndata["feat"].shape[1], 128, dataset.num_classes).to(device)
with tqdm.tqdm(train_dataloader) as tq:
for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
inputs = mfgs[0].srcdata['feat']
labels = mfgs[-1].dstdata['label']
inputs = mfgs[0].srcdata["feat"]
labels = mfgs[-1].dstdata["label"]
predictions = model(mfgs, inputs)
@@ -232,6 +245,7 @@ with tqdm.tqdm(train_dataloader) as tq:
# Say you start with a GNN module that works for full-graph training only:
#
class SAGEConv(nn.Module):
"""Graph convolution module used by the GraphSAGE model.
@@ -242,6 +256,7 @@ class SAGEConv(nn.Module):
out_feat : int
Output feature size.
"""
def __init__(self, in_feat, out_feat):
super().__init__()
# A linear submodule for projecting the input and neighbor feature to the output.
@@ -258,10 +273,13 @@ class SAGEConv(nn.Module):
The input node feature.
"""
with g.local_scope():
g.ndata['h'] = h
g.ndata["h"] = h
# update_all is a message passing API.
g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
h_N = g.ndata['h_N']
g.update_all(
message_func=fn.copy_u("h", "m"),
reduce_func=fn.mean("m", "h_N"),
)
h_N = g.ndata["h_N"]
h_total = torch.cat([h, h_N], dim=1)
return self.linear(h_total)
@@ -352,6 +370,7 @@ class SAGEConv(nn.Module):
# to something like the following:
#
class SAGEConvForBoth(nn.Module):
"""Graph convolution module used by the GraphSAGE model.
@@ -362,6 +381,7 @@ class SAGEConvForBoth(nn.Module):
out_feat : int
Output feature size.
"""
def __init__(self, in_feat, out_feat):
super().__init__()
# A linear submodule for projecting the input and neighbor feature to the output.
@@ -383,10 +403,13 @@ class SAGEConvForBoth(nn.Module):
else:
h_src = h_dst = h
g.srcdata['h'] = h_src
g.srcdata["h"] = h_src
# update_all is a message passing API.
g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
h_N = g.ndata['h_N']
g.update_all(
message_func=fn.copy_u("h", "m"),
reduce_func=fn.mean("m", "h_N"),
)
h_N = g.ndata["h_N"]
h_total = torch.cat([h_dst, h_N], dim=1)
return self.linear(h_total)
......
@@ -46,13 +46,13 @@ message passing APIs.
# representations :math:`h_v`, we can simply use builtin functions:
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch as th
import torch.nn as nn
import torch.nn.functional as F
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
gcn_msg = fn.copy_u(u="h", out="m")
@@ -156,7 +156,6 @@ dur = []
for epoch in range(50):
if epoch >= 3:
t0 = time.time()
net.train()
logits = net(g, features)
logp = F.log_softmax(logits, 1)
@@ -168,14 +167,12 @@ for epoch in range(50):
if epoch >= 3:
dur.append(time.time() - t0)
acc = evaluate(net, g, features, labels, test_mask)
print(
"Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), acc, np.mean(dur)
)
)
###############################################################################
# .. _math:
#
......
@@ -137,18 +137,29 @@ multiple edges among any given pair.
#
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
from functools import partial
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
import dgl.function as fn
from functools import partial
class RGCNLayer(nn.Module):
def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
activation=None, is_input_layer=False):
def __init__(
self,
in_feat,
out_feat,
num_rels,
num_bases=-1,
bias=None,
activation=None,
is_input_layer=False,
):
super(RGCNLayer, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
@@ -161,69 +172,85 @@ class RGCNLayer(nn.Module):
# sanity check
if self.num_bases <= 0 or self.num_bases > self.num_rels:
self.num_bases = self.num_rels
# weight bases in equation (3)
self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,
self.out_feat))
self.weight = nn.Parameter(
torch.Tensor(self.num_bases, self.in_feat, self.out_feat)
)
if self.num_bases < self.num_rels:
# linear combination coefficients in equation (3)
self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))
self.w_comp = nn.Parameter(
torch.Tensor(self.num_rels, self.num_bases)
)
# add bias
if self.bias:
self.bias = nn.Parameter(torch.Tensor(out_feat))
# init trainable parameters
nn.init.xavier_uniform_(self.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.weight, gain=nn.init.calculate_gain("relu")
)
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.w_comp, gain=nn.init.calculate_gain("relu")
)
if self.bias:
nn.init.xavier_uniform_(self.bias,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(
self.bias, gain=nn.init.calculate_gain("relu")
)
def forward(self, g):
if self.num_bases < self.num_rels:
# generate all weights from bases (equation (3))
weight = self.weight.view(self.in_feat, self.num_bases, self.out_feat)
weight = torch.matmul(self.w_comp, weight).view(self.num_rels,
self.in_feat, self.out_feat)
weight = self.weight.view(
self.in_feat, self.num_bases, self.out_feat
)
weight = torch.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat
)
else:
weight = self.weight
if self.is_input_layer:
def message_func(edges):
# for the input layer, matrix multiplication can be converted to
# an embedding lookup using the source node id
embed = weight.view(-1, self.out_feat)
index = edges.data[dgl.ETYPE] * self.in_feat + edges.src['id']
return {'msg': embed[index] * edges.data['norm']}
index = edges.data[dgl.ETYPE] * self.in_feat + edges.src["id"]
return {"msg": embed[index] * edges.data["norm"]}
else:
def message_func(edges):
w = weight[edges.data[dgl.ETYPE]]
msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
msg = msg * edges.data['norm']
return {'msg': msg}
msg = torch.bmm(edges.src["h"].unsqueeze(1), w).squeeze()
msg = msg * edges.data["norm"]
return {"msg": msg}
def apply_func(nodes):
h = nodes.data['h']
h = nodes.data["h"]
if self.bias:
h = h + self.bias
if self.activation:
h = self.activation(h)
return {'h': h}
return {"h": h}
g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)
g.update_all(message_func, fn.sum(msg="msg", out="h"), apply_func)
###############################################################################
# Full R-GCN model defined
# ~~~~~~~~~~~~~~~~~~~~~~~
class Model(nn.Module):
def __init__(self, num_nodes, h_dim, out_dim, num_rels,
num_bases=-1, num_hidden_layers=1):
def __init__(
self,
num_nodes,
h_dim,
out_dim,
num_rels,
num_bases=-1,
num_hidden_layers=1,
):
super(Model, self).__init__()
self.num_nodes = num_nodes
self.h_dim = h_dim
@@ -257,23 +284,40 @@ class Model(nn.Module):
return features
def build_input_layer(self):
return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases,
activation=F.relu, is_input_layer=True)
return RGCNLayer(
self.num_nodes,
self.h_dim,
self.num_rels,
self.num_bases,
activation=F.relu,
is_input_layer=True,
)
def build_hidden_layer(self):
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,
activation=F.relu)
return RGCNLayer(
self.h_dim,
self.h_dim,
self.num_rels,
self.num_bases,
activation=F.relu,
)
def build_output_layer(self):
return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases,
activation=partial(F.softmax, dim=1))
return RGCNLayer(
self.h_dim,
self.out_dim,
self.num_rels,
self.num_bases,
activation=partial(F.softmax, dim=1),
)
def forward(self, g):
if self.features is not None:
g.ndata['id'] = self.features
g.ndata["id"] = self.features
for layer in self.layers:
layer(g)
return g.ndata.pop('h')
return g.ndata.pop("h")
###############################################################################
# Handle dataset
@@ -284,16 +328,16 @@ class Model(nn.Module):
dataset = dgl.data.rdf.AIFBDataset()
g = dataset[0]
category = dataset.predict_category
train_mask = g.nodes[category].data.pop('train_mask')
test_mask = g.nodes[category].data.pop('test_mask')
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop('label')
labels = g.nodes[category].data.pop("label")
num_rels = len(g.canonical_etypes)
num_classes = dataset.num_classes
# normalization factor
for cetype in g.canonical_etypes:
g.edges[cetype].data['norm'] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
g.edges[cetype].data["norm"] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
category_id = g.ntypes.index(category)
###############################################################################
@@ -309,17 +353,19 @@ lr = 0.01 # learning rate
l2norm = 0 # L2 norm coefficient
# create graph
g = dgl.to_homogeneous(g, edata=['norm'])
g = dgl.to_homogeneous(g, edata=["norm"])
node_ids = torch.arange(g.num_nodes())
target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
# create model
model = Model(g.num_nodes(),
model = Model(
g.num_nodes(),
n_hidden,
num_classes,
num_rels,
num_bases=n_bases,
num_hidden_layers=n_hidden_layers)
num_hidden_layers=n_hidden_layers,
)
###############################################################################
# Training loop
@@ -344,12 +390,15 @@ for epoch in range(n_epochs):
val_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
val_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx])
val_acc = val_acc.item() / len(test_idx)
print("Epoch {:05d} | ".format(epoch) +
"Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
train_acc, loss.item()) +
"Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
val_acc, val_loss.item()))
print(
"Epoch {:05d} | ".format(epoch)
+ "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
train_acc, loss.item()
)
+ "Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
val_acc, val_loss.item()
)
)
###############################################################################
# .. _link-prediction:
#
......
@@ -87,19 +87,19 @@ Line Graph Neural Network
# than inter-class.
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.data import citation_graph as citegrh
data = citegrh.load_cora()
G = data[0]
labels = th.tensor(G.ndata['label'])
labels = th.tensor(G.ndata["label"])
# find all the nodes labeled with class 0
label0_nodes = th.nonzero(labels == 0, as_tuple=False).squeeze()
@@ -108,7 +108,9 @@ src, _ = G.in_edges(label0_nodes)
src_labels = labels[src]
# find all the edges whose both endpoints are in class 0
intra_src = th.nonzero(src_labels == 0, as_tuple=False)
print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels)))
print("Intra-class edges percent: %.4f" % (len(intra_src) / len(src_labels)))
import matplotlib.pyplot as plt
###########################################################################################
# Binary community subgraph from Cora with a test dataset
@@ -127,19 +129,30 @@ print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels)))
# With the following code, you can visualize one of the training samples and its community structure.
import networkx as nx
import matplotlib.pyplot as plt
train_set = dgl.data.CoraBinary()
G1, pmpd1, label1 = train_set[1]
nx_G1 = G1.to_networkx()
def visualize(labels, g):
pos = nx.spring_layout(g, seed=1)
plt.figure(figsize=(8, 8))
plt.axis('off')
nx.draw_networkx(g, pos=pos, node_size=50, cmap=plt.get_cmap('coolwarm'),
node_color=labels, edge_color='k',
arrows=False, width=0.5, style='dotted', with_labels=False)
plt.axis("off")
nx.draw_networkx(
g,
pos=pos,
node_size=50,
cmap=plt.get_cmap("coolwarm"),
node_color=labels,
edge_color="k",
arrows=False,
width=0.5,
style="dotted",
with_labels=False,
)
visualize(label1, nx_G1)
###########################################################################################
@@ -369,20 +382,23 @@ visualize(label1, nx_G1)
# Return a list containing features gathered from multiple radii.
import dgl.function as fn
def aggregate_radius(radius, g, z):
# initializing list to collect message passing result
z_list = []
g.ndata['z'] = z
g.ndata["z"] = z
# pulling message from 1-hop neighbourhood
g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z'])
g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
z_list.append(g.ndata["z"])
for i in range(radius - 1):
for j in range(2 ** i):
#pulling message from 2^j neighborhood
g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z'])
for j in range(2**i):
# pulling message from 2^j neighborhood
g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
z_list.append(g.ndata["z"])
return z_list
#########################################################################
# Implementing :math:`\text{fuse}` as sparse matrix multiplication
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
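In this tutorial the incidence matrices {Pm, Pd} are stored as sparse tensors, so the fuse term reduces to a sparse-dense matrix multiplication; a minimal sketch with hypothetical shapes (the tutorial's own code is collapsed in this diff):

# Sketch: project companion-graph features through the sparse incidence
# matrix; the result is what linear_fuse consumes in LGNNCore.forward.
import torch as th

pm_pd = th.eye(4).to_sparse()        # stand-in for the real {Pm, Pd}
feat_b = th.randn(4, 8)              # features from the companion graph
fused = th.sparse.mm(pm_pd, feat_b)  # the fuse input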
@@ -428,7 +444,8 @@ class LGNNCore(nn.Module):
self.linear_prev = nn.Linear(in_feats, out_feats)
self.linear_deg = nn.Linear(in_feats, out_feats)
self.linear_radius = nn.ModuleList(
[nn.Linear(in_feats, out_feats) for i in range(radius)])
[nn.Linear(in_feats, out_feats) for i in range(radius)]
)
self.linear_fuse = nn.Linear(in_feats, out_feats)
self.bn = nn.BatchNorm1d(out_feats)
@@ -442,7 +459,9 @@ class LGNNCore(nn.Module):
# aggregate 2^j-hop features
hop2j_list = aggregate_radius(self.radius, g, feat_a)
# apply linear transformation
hop2j_list = [linear(x) for linear, x in zip(self.linear_radius, hop2j_list)]
hop2j_list = [
linear(x) for linear, x in zip(self.linear_radius, hop2j_list)
]
radius_proj = sum(hop2j_list)
# term "fuse"
@@ -458,6 +477,7 @@ class LGNNCore(nn.Module):
return result
##############################################################################################################
# Chain-up LGNN abstractions as an LGNN layer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -482,6 +502,7 @@ class LGNNLayer(nn.Module):
next_lg_x = self.lg_layer(lg, lg_x, x, deg_lg, pm_pd_y)
return next_x, next_lg_x
########################################################################################
# Chain-up LGNN layers
# ~~~~~~~~~~~~~~~~~~~~
@@ -504,15 +525,17 @@ class LGNN(nn.Module):
x, lg_x = self.layer2(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
x, lg_x = self.layer3(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
return self.linear(x)
#########################################################################################
# Training and inference
# -----------------------
# First load the data.
from torch.utils.data import DataLoader
training_loader = DataLoader(train_set,
batch_size=1,
collate_fn=train_set.collate_fn,
drop_last=True)
training_loader = DataLoader(
train_set, batch_size=1, collate_fn=train_set.collate_fn, drop_last=True
)
#######################################################################################
# Next, define the main training loop. Note that each training sample contains
@@ -536,9 +559,12 @@ optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
def sparse2th(mat):
value = mat.data
indices = th.LongTensor([mat.row, mat.col])
tensor = th.sparse.FloatTensor(indices, th.from_numpy(value).float(), mat.shape)
tensor = th.sparse.FloatTensor(
indices, th.from_numpy(value).float(), mat.shape
)
return tensor
# Train for 20 epochs
for i in range(20):
all_loss = []
@@ -571,11 +597,11 @@ for i in range(20):
optimizer.zero_grad()
loss.backward()
optimizer.step()
niters = len(all_loss)
print("Epoch %d | loss %.4f | accuracy %.4f" % (i,
sum(all_loss) / niters, sum(all_acc) / niters))
print(
"Epoch %d | loss %.4f | accuracy %.4f"
% (i, sum(all_loss) / niters, sum(all_acc) / niters)
)
#######################################################################################
# Visualize training progress
# -----------------------------
@@ -613,6 +639,7 @@ visualize(label1, nx_G1)
# :math`\{Pm,Pd\}` as block diagonal matrix in correspondence to DGL batched
# graph API.
def collate_fn(batch):
graphs, pmpds, labels = zip(*batch)
batched_graphs = dgl.batch(graphs)
@@ -620,6 +647,7 @@ def collate_fn(batch):
batched_labels = np.concatenate(labels, axis=0)
return batched_graphs, batched_pmpds, batched_labels
######################################################################################
# You can find the complete code on Github at
# `Community Detection with Graph Neural Networks (CDGNN) <https://github.com/dmlc/dgl/tree/master/examples/pytorch/line_graph>`_.
@@ -105,9 +105,8 @@ structure-free normalization, in the style of attention.
# subpackage. Simply import ``GATConv`` as follows.
import os
os.environ['DGLBACKEND'] = 'pytorch'
from dgl.nn.pytorch import GATConv
os.environ["DGLBACKEND"] = "pytorch"
###############################################################
# Readers can skip the following step-by-step explanation of the implementation and
# jump to the `Put everything together`_ for training and visualization results.
@@ -125,6 +124,7 @@ from dgl.nn.pytorch import GATConv
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
class GATLayer(nn.Module):
@@ -139,37 +139,38 @@ class GATLayer(nn.Module):
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc.weight, gain=gain)
nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
def edge_attention(self, edges):
# edge UDF for equation (2)
z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
a = self.attn_fc(z2)
return {'e': F.leaky_relu(a)}
return {"e": F.leaky_relu(a)}
def message_func(self, edges):
# message UDF for equation (3) & (4)
return {'z': edges.src['z'], 'e': edges.data['e']}
return {"z": edges.src["z"], "e": edges.data["e"]}
def reduce_func(self, nodes):
# reduce UDF for equation (3) & (4)
# equation (3)
alpha = F.softmax(nodes.mailbox['e'], dim=1)
alpha = F.softmax(nodes.mailbox["e"], dim=1)
# equation (4)
h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
return {'h': h}
h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
return {"h": h}
def forward(self, h):
# equation (1)
z = self.fc(h)
self.g.ndata['z'] = z
self.g.ndata["z"] = z
# equation (2)
self.g.apply_edges(self.edge_attention)
# equation (3) & (4)
self.g.update_all(self.message_func, self.reduce_func)
return self.g.ndata.pop('h')
return self.g.ndata.pop("h")
##################################################################
# Equation (1)
@@ -195,11 +196,13 @@ class GATLayer(nn.Module):
# ``apply_edges`` API. The argument to the ``apply_edges`` is an **Edge UDF**,
# which is defined as below:
def edge_attention(self, edges):
# edge UDF for equation (2)
z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
a = self.attn_fc(z2)
return {'e' : F.leaky_relu(a)}
return {"e": F.leaky_relu(a)}
#########################################################################
# Here, the dot product with the learnable weight vector :math:`\vec{a^{(l)}}`
@@ -229,13 +232,15 @@ def edge_attention(self, edges):
# Both tasks first fetch data from the mailbox and then manipulate it on the
# second dimension (``dim=1``), on which the messages are batched.
def reduce_func(self, nodes):
# reduce UDF for equation (3) & (4)
# equation (3)
alpha = F.softmax(nodes.mailbox['e'], dim=1)
alpha = F.softmax(nodes.mailbox["e"], dim=1)
# equation (4)
h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
return {'h' : h}
h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
return {"h": h}
#####################################################################
# Multi-head attention
@@ -258,8 +263,9 @@ def reduce_func(self, nodes):
# Use the above defined single-head ``GATLayer`` as the building block
# for the ``MultiHeadGATLayer`` below:
class MultiHeadGATLayer(nn.Module):
def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
def __init__(self, g, in_dim, out_dim, num_heads, merge="cat"):
super(MultiHeadGATLayer, self).__init__()
self.heads = nn.ModuleList()
for i in range(num_heads):
@@ -268,19 +274,21 @@ class MultiHeadGATLayer(nn.Module):
def forward(self, h):
head_outs = [attn_head(h) for attn_head in self.heads]
if self.merge == 'cat':
if self.merge == "cat":
# concat on the output feature dimension (dim=1)
return torch.cat(head_outs, dim=1)
else:
# merge using average
return torch.mean(torch.stack(head_outs))
###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now, you can define a two-layer GAT model.
class GAT(nn.Module):
def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
super(GAT, self).__init__()
@@ -296,33 +304,34 @@ class GAT(nn.Module):
h = self.layer2(h)
return h
import networkx as nx
#############################################################################
# We then load the Cora dataset using DGL's built-in data module.
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx
def load_cora_data():
data = citegrh.load_cora()
g = data[0]
mask = torch.BoolTensor(g.ndata['train_mask'])
return g, g.ndata['feat'], g.ndata['label'], mask
mask = torch.BoolTensor(g.ndata["train_mask"])
return g, g.ndata["feat"], g.ndata["label"], mask
##############################################################################
# The training loop is exactly the same as in the GCN tutorial.
import time
import numpy as np
g, features, labels, mask = load_cora_data()
# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
in_dim=features.size()[1],
hidden_dim=8,
out_dim=7,
num_heads=2)
net = GAT(g, in_dim=features.size()[1], hidden_dim=8, out_dim=7, num_heads=2)
# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
@@ -344,8 +353,11 @@ for epoch in range(30):
if epoch >= 3:
dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), np.mean(dur)))
print(
"Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), np.mean(dur)
)
)
#########################################################################
# Visualizing and understanding attention learned
......