Unverified Commit dce89919 authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
from ruamel.yaml.comments import CommentedMap

@@ -14,12 +13,14 @@ def deep_convert_dict(layer):
    return to_ret


import collections.abc


def merge_comment(d, comment_dict, column=30):
    for k, v in comment_dict.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = merge_comment(d.get(k, CommentedMap()), v)
        else:
            d.yaml_add_eol_comment(v, key=k, column=column)
    return d
\ No newline at end of file
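# A minimal usage sketch of the helper above (the config and comment maps are
# hypothetical, not part of the commit): each leaf comment is attached as an
# end-of-line YAML comment at the requested column.
import sys

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap

cfg = CommentedMap({"lr": 0.01, "epochs": 200})
merge_comment(cfg, {"lr": "learning rate", "epochs": "training epochs"})
YAML().dump(cfg, sys.stdout)
# lr: 0.01                      # learning rate
# epochs: 200                   # training epochs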
#!/usr/bin/env python
from distutils.core import setup

from setuptools import find_packages

setup(
    name="dglgo",
    version="0.0.2",
    description="DGL",
    author="DGL Team",
    author_email="wmjlyjemaine@gmail.com",
    packages=find_packages(),
    install_requires=[
        "typer>=0.4.0",
        "isort>=5.10.1",
        "autopep8>=1.6.0",
        "numpydoc>=1.1.0",
        "pydantic>=1.9.0",
        "ruamel.yaml>=0.17.20",
        "PyYAML>=5.1",
        "ogb>=1.3.3",
        "rdkit-pypi",
        "scikit-learn>=0.20.0",
    ],
    package_data={"": ["./*"]},
    include_package_data=True,
    license="APACHE",
    entry_points={"console_scripts": ["dgl = dglgo.cli.cli:main"]},
    url="https://github.com/dmlc/dgl",
)
@@ -14,16 +14,18 @@
#
import os
import sys

sys.path.insert(0, os.path.abspath("../../python"))

# -- Project information -----------------------------------------------------

project = "DGL"
copyright = "2018, DGL Team"
author = "DGL Team"

import dgl

version = dgl.__version__
release = dgl.__version__

dglbackend = os.environ.get("DGLBACKEND", "pytorch")

@@ -39,35 +41,35 @@ dglbackend = os.environ.get("DGLBACKEND", "pytorch")
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.coverage",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.intersphinx",
    "sphinx.ext.graphviz",
    "sphinxemoji.sphinxemoji",
    "sphinx_gallery.gen_gallery",
    "sphinx_copybutton",
    "nbsphinx",
    "nbsphinx_link",
]

# Do not run notebooks on non-pytorch backends
if dglbackend != "pytorch":
    nbsphinx_execute = "never"

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffixes as a list of strings:
#
source_suffix = [".rst", ".md"]

# The master toctree document.
master_doc = "index"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

@@ -90,7 +92,7 @@ pygments_style = None
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the

@@ -101,8 +103,8 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.

@@ -118,7 +120,7 @@ html_css_files = ['css/custom.css']
# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "dgldoc"

# -- Options for LaTeX output ------------------------------------------------

@@ -127,15 +129,12 @@ latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',

@@ -145,8 +144,7 @@ latex_elements = {
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, "dgl.tex", "DGL Documentation", "DGL Team", "manual"),
]

@@ -154,10 +152,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "dgl", "DGL Documentation", [author], 1)]

# -- Options for Texinfo output ----------------------------------------------

@@ -166,9 +161,15 @@ man_pages = [
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "dgl",
        "DGL Documentation",
        author,
        "dgl",
        "Library for deep learning on graphs.",
        "Miscellaneous",
    ),
]

@@ -187,64 +188,71 @@ epub_title = project
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ["search.html"]

# -- Extension configuration -------------------------------------------------

autosummary_generate = True
autodoc_member_order = "alphabetical"

intersphinx_mapping = {
    "python": (
        "https://docs.python.org/{.major}".format(sys.version_info),
        None,
    ),
    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
    "scipy": ("http://docs.scipy.org/doc/scipy/reference", None),
    "matplotlib": ("http://matplotlib.org/", None),
    "networkx": ("https://networkx.github.io/documentation/stable", None),
}

# sphinx gallery configurations
from sphinx_gallery.sorting import FileNameSortKey

examples_dirs = [
    "../../tutorials/blitz",
    "../../tutorials/large",
    "../../tutorials/dist",
    "../../tutorials/models",
    "../../tutorials/multi",
    "../../tutorials/cpu",
]  # path to find sources
gallery_dirs = [
    "tutorials/blitz/",
    "tutorials/large/",
    "tutorials/dist/",
    "tutorials/models/",
    "tutorials/multi/",
    "tutorials/cpu",
]  # path to generate docs

if dglbackend != "pytorch":
    examples_dirs = []
    gallery_dirs = []

reference_url = {
    "dgl": None,
    "numpy": "http://docs.scipy.org/doc/numpy/",
    "scipy": "http://docs.scipy.org/doc/scipy/reference",
    "matplotlib": "http://matplotlib.org/",
    "networkx": "https://networkx.github.io/documentation/stable",
}

sphinx_gallery_conf = {
    "backreferences_dir": "generated/backreferences",
    "doc_module": ("dgl", "numpy"),
    "examples_dirs": examples_dirs,
    "gallery_dirs": gallery_dirs,
    "within_subsection_order": FileNameSortKey,
    "filename_pattern": ".py",
    "download_all_examples": False,
}

# Compatibility for different backends when building tutorials
if dglbackend == "mxnet":
    sphinx_gallery_conf["filename_pattern"] = "/*(?<=mx)\.py"
if dglbackend == "pytorch":
    sphinx_gallery_conf["filename_pattern"] = "/*(?<!mx)\.py"

# sphinx-copybutton tool
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
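# A quick sanity check (with hypothetical tutorial filenames, not part of the
# commit) of the two lookbehind patterns above: files ending in "mx.py" build
# under the MXNet backend, everything else under PyTorch.
import re

for fname in ["1_gnn_mx.py", "1_gnn.py"]:
    print(
        fname,
        bool(re.search(r"/*(?<=mx)\.py", fname)),  # mxnet pattern
        bool(re.search(r"/*(?<!mx)\.py", fname)),  # pytorch pattern
    )
# 1_gnn_mx.py True False
# 1_gnn.py False True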
import numpy as np
import pandas as pd

from dgl import DGLGraph

# from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset
from dgl.data.bitcoinotc import BitcoinOTC
from dgl.data.gdelt import GDELT
from dgl.data.gindt import GINDataset
from dgl.data.gnn_benchmark import AmazonCoBuy, Coauthor, CoraFull
from dgl.data.icews18 import ICEWS18
from dgl.data.karate import KarateClub
from dgl.data.qm7b import QM7b
from pytablewriter import MarkdownTableWriter, RstGridTableWriter

ds_list = {
    "BitcoinOTC": "BitcoinOTC()",

@@ -40,9 +41,9 @@ writer = RstGridTableWriter()
# writer = MarkdownTableWriter()

extract_graph = lambda g: g if isinstance(g, DGLGraph) else g[0]

stat_list = []
for k, v in ds_list.items():
    print(k, " ", v)
    ds = eval(v.split("/")[0])
    num_nodes = []
    num_edges = []

@@ -58,10 +59,10 @@ for k,v in ds_list.items():
        "# of graphs": len(ds),
        "Avg. # of nodes": np.mean(num_nodes),
        "Avg. # of edges": np.mean(num_edges),
        "Node field": ", ".join(list(gg.ndata.keys())),
        "Edge field": ", ".join(list(gg.edata.keys())),
        # "Graph field": ', '.join(ds[0][0].gdata.keys()) if hasattr(ds[0][0], "gdata") else "",
        "Temporal": hasattr(ds, "is_temporal"),
    }
    stat_list.append(dd)
...
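# The script is truncated above; a plausible continuation (assuming the
# `writer` and `stat_list` built earlier) renders the collected statistics
# as an RST grid table:
writer.from_dataframe(pd.DataFrame(stat_list))
writer.write_table()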
@@ -26,15 +26,14 @@ def get_sddmm_kernels_gpu(idtypes, dtypes):
    return ret


if __name__ == "__main__":
    binary_path = "libfeatgraph_kernels.so"
    kernels = []
    idtypes = ["int32", "int64"]
    dtypes = ["float16", "float64", "float32", "int32", "int64"]
    kernels += get_sddmm_kernels_gpu(idtypes, dtypes)
    # build kernels and export the module to libfeatgraph_kernels.so
    module = tvm.build(kernels, target="cuda", target_host="llvm")
    module.export_library(binary_path)
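# A minimal sketch (assuming the build above succeeded on a CUDA machine) of
# consuming the exported library: load it and look up one generated kernel by
# the name format used in sddmm.py below.
import tvm

mod = tvm.runtime.load_module("libfeatgraph_kernels.so")
func = mod["SDDMMTreeReduction_int32_float32"]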
@@ -4,8 +4,8 @@ from tvm import te


def sddmm_tree_reduction_gpu(idx_type, feat_type):
    """SDDMM kernels on GPU optimized with Tree Reduction.

    Parameters
    ----------
    idx_type : str

@@ -19,35 +19,40 @@ def sddmm_tree_reduction_gpu(idx_type, feat_type):
        The result IRModule.
    """
    # define vars and placeholders
    nnz = te.var("nnz", idx_type)
    num_rows = te.var("num_rows", idx_type)
    num_cols = te.var("num_cols", idx_type)
    H = te.var("num_heads", idx_type)
    D = te.var("feat_len", idx_type)
    row = te.placeholder((nnz,), idx_type, "row")
    col = te.placeholder((nnz,), idx_type, "col")
    ufeat = te.placeholder((num_rows, H, D), feat_type, "ufeat")
    vfeat = te.placeholder((num_cols, H, D), feat_type, "vfeat")

    # define edge computation function
    def edge_func(eid, h, i):
        k = te.reduce_axis((0, D), name="k")
        return te.sum(ufeat[row[eid], h, k] * vfeat[col[eid], h, k], axis=k)

    out = te.compute(
        (nnz, H, tvm.tir.IntImm(idx_type, 1)), edge_func, name="out"
    )
    # define schedules
    sched = te.create_schedule(out.op)
    edge_axis, head_axis, _ = out.op.axis
    reduce_axis = out.op.reduce_axis[0]
    _, red_inner = sched[out].split(reduce_axis, factor=32)
    edge_outer, edge_inner = sched[out].split(edge_axis, factor=32)
    sched[out].bind(red_inner, te.thread_axis("threadIdx.x"))
    sched[out].bind(edge_inner, te.thread_axis("threadIdx.y"))
    sched[out].bind(edge_outer, te.thread_axis("blockIdx.x"))
    sched[out].bind(head_axis, te.thread_axis("blockIdx.y"))
    return tvm.lower(
        sched,
        [row, col, ufeat, vfeat, out],
        name="SDDMMTreeReduction_{}_{}".format(idx_type, feat_type),
    )


if __name__ == "__main__":
    kernel0 = sddmm_tree_reduction_gpu("int32", "float32")
    print(kernel0)
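# Reference semantics of the kernel above in plain NumPy (hypothetical sizes,
# not part of the commit): for each edge e and head h, out[e, h, 0] is the dot
# product of the source row of ufeat and the destination row of vfeat.
import numpy as np

nnz, num_rows, num_cols, H, D = 4, 3, 3, 2, 8
row = np.array([0, 1, 2, 0])
col = np.array([1, 2, 0, 2])
ufeat = np.random.rand(num_rows, H, D)
vfeat = np.random.rand(num_cols, H, D)
out = np.einsum("ehk,ehk->eh", ufeat[row], vfeat[col])[..., None]
print(out.shape)  # (4, 2, 1)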
import dgl
import dgl.backend as F
import torch

g = dgl.rand_graph(10, 15).int().to(torch.device(0))
gidx = g._graph
u = torch.rand((10, 2, 8), device=torch.device(0))
v = torch.rand((10, 2, 8), device=torch.device(0))
e = dgl.ops.gsddmm(g, "dot", u, v)
print(e)
e = torch.zeros((15, 2, 1), device=torch.device(0))
u = F.zerocopy_to_dgl_ndarray(u)
v = F.zerocopy_to_dgl_ndarray(v)
e = F.zerocopy_to_dgl_ndarray_for_write(e)
...
@@ -22,13 +22,13 @@ networks with PyTorch.
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Overview of Node Classification with GNN
...
@@ -31,11 +31,11 @@ By the end of this tutorial you will be able to:
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch

g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)
# Equivalently, PyTorch LongTensors also work.
...
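# The tutorial is truncated here; the equivalent construction the comment
# refers to would look like this (a sketch, not the elided original):
src = torch.LongTensor([0, 0, 0, 0, 0])
dst = torch.LongTensor([1, 2, 3, 4, 5])
g = dgl.graph((src, dst), num_nodes=6)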
@@ -19,13 +19,13 @@ GNN for node classification <1_introduction>`.
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Message passing and GNNs
...
@@ -19,17 +19,17 @@ By the end of this tutorial you will be able to
import itertools
import os

os.environ["DGLBACKEND"] = "pytorch"

import dgl
import dgl.data
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Overview of Link Prediction with GNN
# ------------------------------------
...
@@ -14,13 +14,13 @@ By the end of this tutorial, you will be able to
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Overview of Graph Classification with GNN

@@ -54,6 +54,8 @@ print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)

from dgl.dataloading import GraphDataLoader

######################################################################
# Defining Data Loader
# --------------------

@@ -74,8 +76,6 @@ print("Number of graph categories:", dataset.gclasses)
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)
...
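# A plausible continuation (the tutorial is truncated here): split the indices
# 80/20 and feed them to GraphDataLoader via SubsetRandomSampler. The batch
# size is a hypothetical choice.
train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))
train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)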
@@ -88,10 +88,10 @@ interactions.head()
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
from dgl.data import DGLDataset
...
@@ -26,10 +26,11 @@ Sampling for GNN Training <L0_neighbor_sampling_overview>`.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")

@@ -284,13 +285,14 @@ valid_dataloader = dgl.dataloading.DataLoader(
)

import sklearn.metrics

######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
...
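# A sketch of the validation step the comment above describes (assuming the
# `model` and `valid_dataloader` defined in the elided part of the tutorial):
# accumulate predictions, score them with scikit-learn, and checkpoint on
# improvement.
predictions = []
labels = []
with torch.no_grad():
    for input_nodes, output_nodes, mfgs in valid_dataloader:
        inputs = mfgs[0].srcdata["feat"]
        labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
        predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
accuracy = sklearn.metrics.accuracy_score(
    np.concatenate(labels), np.concatenate(predictions)
)
if accuracy > best_accuracy:
    best_accuracy = accuracy
    torch.save(model.state_dict(), best_model_path)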
@@ -53,10 +53,11 @@ Sampling for Node Classification <L1_large_node_classification>`.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")

@@ -339,6 +340,8 @@ predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))

import sklearn.metrics

######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the

@@ -346,7 +349,6 @@ opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
...
@@ -14,30 +14,33 @@ for stochastic GNN training. It assumes that
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata["label"] = node_labels[:, 0]

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
node_features = graph.ndata["feat"]

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    graph,
    train_nids,
    sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)

input_nodes, output_nodes, mfgs = next(iter(train_dataloader))
@@ -75,8 +78,8 @@ print(mfg.num_src_nodes(), mfg.num_dst_nodes())
# will do with ``ndata`` on the graphs you have seen earlier:
#

mfg.srcdata["x"] = torch.zeros(mfg.num_src_nodes(), mfg.num_dst_nodes())
dst_feat = mfg.dstdata["feat"]
######################################################################

@@ -105,7 +108,11 @@ mfg.srcdata[dgl.NID], mfg.dstdata[dgl.NID]
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#

print(
    torch.equal(
        mfg.srcdata[dgl.NID][: mfg.num_dst_nodes()], mfg.dstdata[dgl.NID]
    )
)
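# The prefix property checked above can be seen on a tiny hand-built block
# (hypothetical edges, not part of the tutorial): dgl.to_block lists the
# destination nodes first among the source nodes.
block = dgl.to_block(dgl.graph(([1, 2], [3, 3])), dst_nodes=torch.tensor([3]))
print(block.srcdata[dgl.NID])  # e.g. tensor([3, 1, 2]) -- dst node listed first
print(block.dstdata[dgl.NID])  # tensor([3])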
######################################################################

@@ -113,7 +120,7 @@ print(torch.equal(mfg.srcdata[dgl.NID][:mfg.num_dst_nodes()], mfg.dstdata[dgl.NI
# :math:`h_u^{(l-1)}`:
#

mfg.srcdata["h"] = torch.randn(mfg.num_src_nodes(), 10)

######################################################################

@@ -132,8 +139,8 @@ mfg.srcdata['h'] = torch.randn(mfg.num_src_nodes(), 10)
import dgl.function as fn

mfg.update_all(message_func=fn.copy_u("h", "m"), reduce_func=fn.mean("m", "h"))
m_v = mfg.dstdata["h"]
m_v
@@ -147,6 +154,7 @@ import torch.nn as nn
import torch.nn.functional as F
import tqdm


class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

@@ -157,6 +165,7 @@ class SAGEConv(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.

@@ -174,14 +183,15 @@ class SAGEConv(nn.Module):
        """
        with g.local_scope():
            h_src, h_dst = h
            g.srcdata["h"] = h_src  # <---
            g.dstdata["h"] = h_dst  # <---
            # update_all is a message passing API.
            g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
            h_N = g.dstdata["h_N"]
            h_total = torch.cat([h_dst, h_N], dim=1)  # <---
            return self.linear(h_total)
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()

@@ -189,28 +199,31 @@ class Model(nn.Module):
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h


sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    graph,
    train_nids,
    sampler,
    device=device,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)
model = Model(graph.ndata["feat"].shape[1], 128, dataset.num_classes).to(device)

with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
        inputs = mfgs[0].srcdata["feat"]
        labels = mfgs[-1].dstdata["label"]
        predictions = model(mfgs, inputs)
@@ -232,6 +245,7 @@ with tqdm.tqdm(train_dataloader) as tq:
# Say you start with a GNN module that works for full-graph training only:
#


class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

@@ -242,6 +256,7 @@ class SAGEConv(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.

@@ -258,10 +273,13 @@ class SAGEConv(nn.Module):
        The input node feature.
        """
        with g.local_scope():
            g.ndata["h"] = h
            # update_all is a message passing API.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
@@ -352,6 +370,7 @@ class SAGEConv(nn.Module):
# to something like the following:
#


class SAGEConvForBoth(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

@@ -362,6 +381,7 @@ class SAGEConvForBoth(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.

@@ -383,10 +403,13 @@ class SAGEConvForBoth(nn.Module):
        else:
            h_src = h_dst = h
        g.srcdata["h"] = h_src
        # update_all is a message passing API.
        g.update_all(
            message_func=fn.copy_u("h", "m"),
            reduce_func=fn.mean("m", "h_N"),
        )
        h_N = g.ndata["h_N"]
        h_total = torch.cat([h_dst, h_N], dim=1)
        return self.linear(h_total)
...
@@ -20,189 +20,186 @@ Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We explain
what is under the hood of the :class:`~dgl.nn.GraphConv` module.
The reader is expected to learn how to define a new GNN layer using DGL's
message passing APIs.
"""

###############################################################################
# Model Overview
# ------------------------------------------
# GCN from the perspective of message passing
# ```````````````````````````````````````````````
# We describe a layer of graph convolutional neural network from a message
# passing perspective; the math can be found `here <math_>`_.
# It boils down to the following step, for each node :math:`u`:
#
# 1) Aggregate neighbors' representations :math:`h_{v}` to produce an
# intermediate representation :math:`\hat{h}_u`. 2) Transform the aggregated
# representation :math:`\hat{h}_{u}` with a linear projection followed by a
# non-linearity: :math:`h_{u} = f(W_{u} \hat{h}_u)`.
#
# We will implement step 1 with DGL message passing, and step 2 by
# PyTorch ``nn.Module``.
#
# GCN implementation with DGL
# ``````````````````````````````````````````
# We first define the message and reduce function as usual. Since the
# aggregation on a node :math:`u` only involves summing over the neighbors'
# representations :math:`h_v`, we can simply use builtin functions:

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_u(u="h", out="m")
gcn_reduce = fn.sum(msg="m", out="h")
###############################################################################
# We then proceed to define the GCNLayer module. A GCNLayer essentially performs
# message passing on all the nodes then applies a fully-connected layer.
#
# .. note::
#
#    This is showing how to implement a GCN from scratch. DGL provides a more
#    efficient :class:`builtin GCN layer module <dgl.nn.pytorch.conv.GraphConv>`.
#


class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata["h"] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata["h"]
            return self.linear(h)
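###############################################################################
# A quick smoke test of the layer above on a random toy graph (hypothetical
# sizes, not part of the original tutorial):

g_toy = dgl.rand_graph(5, 10)
layer = GCNLayer(in_feats=4, out_feats=2)
print(layer(g_toy, th.randn(5, 4)).shape)  # torch.Size([5, 2])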
###############################################################################
# The forward function is essentially the same as any other commonly seen NN
# model in PyTorch. We can initialize GCN like any ``nn.Module``. For example,
# let's define a simple neural network consisting of two GCN layers. Suppose we
# are training the classifier for the cora dataset (the input feature size is
# 1433 and the number of classes is 7). The last GCN layer computes node embeddings,
# so the last layer in general does not apply activation.


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(1433, 16)
        self.layer2 = GCNLayer(16, 7)

    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x


net = Net()
print(net)

###############################################################################
# We load the cora dataset using DGL's built-in data module.

from dgl.data import CoraGraphDataset


def load_cora_data():
    dataset = CoraGraphDataset()
    g = dataset[0]
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    return g, features, labels, train_mask, test_mask


###############################################################################
# When a model is trained, we can use the following method to evaluate
# the performance of the model on the test dataset:


def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)
###############################################################################
# We then train the network as follows:

import time

import numpy as np

g, features, labels, train_mask, test_mask = load_cora_data()
# Add edges between each node and itself to preserve old node representations
g.add_edges(g.nodes(), g.nodes())
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
    if epoch >= 3:
        t0 = time.time()
    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)
    acc = evaluate(net, g, features, labels, test_mask)
    print(
        "Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur)
        )
    )

###############################################################################
# .. _math:
#
# GCN in one formula
# ------------------
# Mathematically, the GCN model follows this formula:
#
# :math:`H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})`
#
# Here, :math:`H^{(l)}` denotes the :math:`l^{th}` layer in the network,
# :math:`\sigma` is the non-linearity, and :math:`W` is the weight matrix for
# this layer. :math:`\tilde{D}` and :math:`\tilde{A}` are the degree
# and adjacency matrices for the graph, respectively. With the superscript ~, we are referring
# to the variant where we add additional edges between each node and itself to
# preserve its old representation in graph convolutions. The shape of the input
# :math:`H^{(0)}` is :math:`N \times D`, where :math:`N` is the number of nodes
# and :math:`D` is the number of input features. We can chain up multiple
# layers as such to produce a node-level representation output with shape
# :math:`N \times F`, where :math:`F` is the dimension of the output node
# feature vector.
#
# The equation can be efficiently implemented using sparse matrix
# multiplication kernels (such as Kipf's
# `pygcn <https://github.com/tkipf/pygcn>`_ code). The above DGL implementation
# in fact has already used this trick due to the use of builtin functions.
#
# Note that the tutorial code implements a simplified version of GCN where we
# replace :math:`\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}` with
# :math:`\tilde{A}`. For a full implementation, see our example
# `here <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn>`_.
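###############################################################################
# A small sanity check (toy graph, not part of the original tutorial) that the
# builtin copy_u/sum message passing used above is exactly multiplication by
# the adjacency structure: every destination node sums its in-neighbors.

g_check = dgl.graph(([0, 1, 2], [1, 2, 0]))
h = th.randn(3, 4)
g_check.ndata["h"] = h
g_check.update_all(gcn_msg, gcn_reduce)
expected = th.zeros_like(h)
src, dst = g_check.edges()
expected.index_add_(0, dst, h[src])  # accumulate h[src] into each dst row
print(th.allclose(g_check.ndata["h"], expected))  # True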
@@ -105,9 +105,8 @@ structure-free normalization, in the style of attention.
# subpackage. Simply import the ``GATConv`` as follows.

import os

os.environ["DGLBACKEND"] = "pytorch"

###############################################################
# Readers can skip the following step-by-step explanation of the implementation and
# jump to the `Put everything together`_ for training and visualization results.

@@ -125,6 +124,7 @@ from dgl.nn.pytorch import GATConv
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv


class GATLayer(nn.Module):

@@ -139,37 +139,38 @@ class GATLayer(nn.Module):
    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
        a = self.attn_fc(z2)
        return {"e": F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {"z": edges.src["z"], "e": edges.data["e"]}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox["e"], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
        return {"h": h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata["z"] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop("h")
##################################################################
# Equation (1)

@@ -195,11 +196,13 @@ class GATLayer(nn.Module):
# ``apply_edges`` API. The argument to the ``apply_edges`` is an **Edge UDF**,
# which is defined as below:


def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
    a = self.attn_fc(z2)
    return {"e": F.leaky_relu(a)}


#########################################################################
# Here, the dot product with the learnable weight vector :math:`\vec{a^{(l)}}`

@@ -229,13 +232,15 @@ def edge_attention(self, edges):
# Both tasks first fetch data from the mailbox and then manipulate it on the
# second dimension (``dim=1``), on which the messages are batched.


def reduce_func(self, nodes):
    # reduce UDF for equation (3) & (4)
    # equation (3)
    alpha = F.softmax(nodes.mailbox["e"], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
    return {"h": h}
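# Reference semantics of ``reduce_func`` on plain tensors (hypothetical sizes,
# not part of the tutorial: one node, three incoming messages, feature size 2):

e = torch.randn(1, 3, 1)  # attention logits, one per message
z = torch.randn(1, 3, 2)  # neighbor features
alpha = F.softmax(e, dim=1)  # equation (3): normalize over messages
h = torch.sum(alpha * z, dim=1)  # equation (4): weighted sum
print(h.shape)  # torch.Size([1, 2])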
#####################################################################
# Multi-head attention

@@ -258,8 +263,9 @@ def reduce_func(self, nodes):
# Use the above defined single-head ``GATLayer`` as the building block
# for the ``MultiHeadGATLayer`` below:


class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge="cat"):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):

@@ -268,19 +274,21 @@ class MultiHeadGATLayer(nn.Module):
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == "cat":
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))


###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now, you can define a two-layer GAT model.


class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()

@@ -296,33 +304,34 @@ class GAT(nn.Module):
        h = self.layer2(h)
        return h
import networkx as nx

#############################################################################
# We then load the Cora dataset using DGL's built-in data module.

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh


def load_cora_data():
    data = citegrh.load_cora()
    g = data[0]
    mask = torch.BoolTensor(g.ndata["train_mask"])
    return g, g.ndata["feat"], g.ndata["label"], mask


##############################################################################
# The training loop is exactly the same as in the GCN tutorial.

import time

import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g, in_dim=features.size()[1], hidden_dim=8, out_dim=7, num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

@@ -344,8 +353,11 @@ for epoch in range(30):
    if epoch >= 3:
        dur.append(time.time() - t0)

    print(
        "Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), np.mean(dur)
        )
    )

#########################################################################
# Visualizing and understanding attention learned
...