Unverified Commit dce89919 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
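The reformat itself was automated. A minimal sketch of the idea, assuming plain black defaults rather than the repository's actual lintrunner configuration:

import black

# Format a source string with black's programmatic API (defaults assumed;
# the real run was driven by lintrunner and the repo's own settings).
src = "x=dict( a=1,b=2 )\n"
print(black.format_str(src, mode=black.Mode()))  # x = dict(a=1, b=2)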
from ruamel.yaml.comments import CommentedMap

@@ -14,8 +13,10 @@ def deep_convert_dict(layer):
    return to_ret

import collections.abc

def merge_comment(d, comment_dict, column=30):
    for k, v in comment_dict.items():
        if isinstance(v, collections.abc.Mapping):
...
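The truncated merge_comment above attaches per-key comments to a parsed config. A hedged sketch of the underlying ruamel.yaml mechanism it relies on (illustrative keys and values, not from this diff):

import sys

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap

# Attach end-of-line comments to a CommentedMap at a fixed column, which is
# what merge_comment does recursively for nested mappings.
d = CommentedMap(lr=0.01, epochs=200)
d.yaml_add_eol_comment("learning rate", key="lr", column=30)
d.yaml_add_eol_comment("training epochs", key="epochs", column=30)
YAML().dump(d, sys.stdout)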
#!/usr/bin/env python
from distutils.core import setup

from setuptools import find_packages

setup(
    name="dglgo",
    version="0.0.2",
    description="DGL",
    author="DGL Team",
    author_email="wmjlyjemaine@gmail.com",
    packages=find_packages(),
    install_requires=[
        "typer>=0.4.0",
        "isort>=5.10.1",
        "autopep8>=1.6.0",
        "numpydoc>=1.1.0",
        "pydantic>=1.9.0",
        "ruamel.yaml>=0.17.20",
        "PyYAML>=5.1",
        "ogb>=1.3.3",
        "rdkit-pypi",
        "scikit-learn>=0.20.0",
    ],
    package_data={"": ["./*"]},
    include_package_data=True,
    license="APACHE",
    entry_points={"console_scripts": ["dgl = dglgo.cli.cli:main"]},
    url="https://github.com/dmlc/dgl",
)
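The entry_points block wires a dgl console command to dglgo.cli.cli:main. A small usage sketch, assuming the package has been installed (e.g. with pip install .):

# Running the entry-point callable directly is equivalent to the `dgl` command.
from dglgo.cli.cli import main  # target of "dgl = dglgo.cli.cli:main"

main()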
@@ -14,16 +14,18 @@
#
import os
import sys

sys.path.insert(0, os.path.abspath("../../python"))

# -- Project information -----------------------------------------------------

project = "DGL"
copyright = "2018, DGL Team"
author = "DGL Team"

import dgl

version = dgl.__version__
release = dgl.__version__
dglbackend = os.environ.get("DGLBACKEND", "pytorch")

@@ -39,35 +41,35 @@ dglbackend = os.environ.get("DGLBACKEND", "pytorch")
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.coverage",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.intersphinx",
    "sphinx.ext.graphviz",
    "sphinxemoji.sphinxemoji",
    "sphinx_gallery.gen_gallery",
    "sphinx_copybutton",
    "nbsphinx",
    "nbsphinx_link",
]

# Do not run notebooks on non-pytorch backends
if dglbackend != "pytorch":
    nbsphinx_execute = "never"

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = [".rst", ".md"]

# The master toctree document.
master_doc = "index"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

@@ -90,7 +92,7 @@ pygments_style = None
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the

@@ -101,8 +103,8 @@ html_theme = 'sphinx_rtd_theme'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_css_files = ["css/custom.css"]

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.

@@ -118,7 +120,7 @@ html_css_files = ['css/custom.css']
# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "dgldoc"

# -- Options for LaTeX output ------------------------------------------------

@@ -127,15 +129,12 @@ latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',

@@ -145,8 +144,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, "dgl.tex", "DGL Documentation", "DGL Team", "manual"),
]

@@ -154,10 +152,7 @@ latex_documents = [
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "dgl", "DGL Documentation", [author], 1)]

# -- Options for Texinfo output ----------------------------------------------

@@ -166,9 +161,15 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "dgl",
        "DGL Documentation",
        author,
        "dgl",
        "Library for deep learning on graphs.",
        "Miscellaneous",
    ),
]

@@ -187,64 +188,71 @@ epub_title = project
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ["search.html"]

# -- Extension configuration -------------------------------------------------
autosummary_generate = True
autodoc_member_order = "alphabetical"

intersphinx_mapping = {
    "python": (
        "https://docs.python.org/{.major}".format(sys.version_info),
        None,
    ),
    "numpy": ("http://docs.scipy.org/doc/numpy/", None),
    "scipy": ("http://docs.scipy.org/doc/scipy/reference", None),
    "matplotlib": ("http://matplotlib.org/", None),
    "networkx": ("https://networkx.github.io/documentation/stable", None),
}

# sphinx gallery configurations
from sphinx_gallery.sorting import FileNameSortKey

examples_dirs = [
    "../../tutorials/blitz",
    "../../tutorials/large",
    "../../tutorials/dist",
    "../../tutorials/models",
    "../../tutorials/multi",
    "../../tutorials/cpu",
]  # path to find sources
gallery_dirs = [
    "tutorials/blitz/",
    "tutorials/large/",
    "tutorials/dist/",
    "tutorials/models/",
    "tutorials/multi/",
    "tutorials/cpu",
]  # path to generate docs
if dglbackend != "pytorch":
    examples_dirs = []
    gallery_dirs = []

reference_url = {
    "dgl": None,
    "numpy": "http://docs.scipy.org/doc/numpy/",
    "scipy": "http://docs.scipy.org/doc/scipy/reference",
    "matplotlib": "http://matplotlib.org/",
    "networkx": "https://networkx.github.io/documentation/stable",
}

sphinx_gallery_conf = {
    "backreferences_dir": "generated/backreferences",
    "doc_module": ("dgl", "numpy"),
    "examples_dirs": examples_dirs,
    "gallery_dirs": gallery_dirs,
    "within_subsection_order": FileNameSortKey,
    "filename_pattern": ".py",
    "download_all_examples": False,
}

# Compatibility for different backend when builds tutorials
if dglbackend == "mxnet":
    sphinx_gallery_conf["filename_pattern"] = "/*(?<=mx)\.py"
if dglbackend == "pytorch":
    sphinx_gallery_conf["filename_pattern"] = "/*(?<!mx)\.py"

# sphinx-copybutton tool
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
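The two backend-specific filename_pattern values select tutorial files by suffix: sphinx-gallery applies the pattern with re.search, so the lookbehind keeps or drops sources ending in mx.py. A quick check with illustrative paths (not from the repo):

import re

mxnet_only = r"/*(?<=mx)\.py"  # keep only *mx.py sources
pytorch_only = r"/*(?<!mx)\.py"  # drop *mx.py sources

print(bool(re.search(mxnet_only, "tutorials/models/1_gnn_mx.py")))  # True
print(bool(re.search(mxnet_only, "tutorials/models/1_gnn.py")))  # False
print(bool(re.search(pytorch_only, "tutorials/models/1_gnn.py")))  # True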
import numpy as np
import pandas as pd

from dgl import DGLGraph

# from dgl.data.qm9 import QM9
from dgl.data import CitationGraphDataset, PPIDataset, RedditDataset, TUDataset
from dgl.data.bitcoinotc import BitcoinOTC
from dgl.data.gdelt import GDELT
from dgl.data.gindt import GINDataset
from dgl.data.gnn_benchmark import AmazonCoBuy, Coauthor, CoraFull
from dgl.data.icews18 import ICEWS18
from dgl.data.karate import KarateClub
from dgl.data.qm7b import QM7b
from pytablewriter import MarkdownTableWriter, RstGridTableWriter

ds_list = {
    "BitcoinOTC": "BitcoinOTC()",

@@ -40,9 +41,9 @@ writer = RstGridTableWriter()
# writer = MarkdownTableWriter()

extract_graph = lambda g: g if isinstance(g, DGLGraph) else g[0]
stat_list = []
for k, v in ds_list.items():
    print(k, " ", v)
    ds = eval(v.split("/")[0])
    num_nodes = []
    num_edges = []

@@ -58,10 +59,10 @@ for k,v in ds_list.items():
        "# of graphs": len(ds),
        "Avg. # of nodes": np.mean(num_nodes),
        "Avg. # of edges": np.mean(num_edges),
        "Node field": ", ".join(list(gg.ndata.keys())),
        "Edge field": ", ".join(list(gg.edata.keys())),
        # "Graph field": ', '.join(ds[0][0].gdata.keys()) if hasattr(ds[0][0], "gdata") else "",
        "Temporal": hasattr(ds, "is_temporal"),
    }
    stat_list.append(dd)
...
@@ -26,15 +26,14 @@ def get_sddmm_kernels_gpu(idtypes, dtypes):
    return ret

if __name__ == "__main__":
    binary_path = "libfeatgraph_kernels.so"
    kernels = []
    idtypes = ["int32", "int64"]
    dtypes = ["float16", "float64", "float32", "int32", "int64"]
    kernels += get_sddmm_kernels_gpu(idtypes, dtypes)
    # build kernels and export the module to libfeatgraph_kernels.so
    module = tvm.build(kernels, target="cuda", target_host="llvm")
    module.export_library(binary_path)
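Once exported, the shared library can be loaded back through TVM's runtime; the kernel names follow the format string used in the kernel file that follows. A hedged sketch:

import tvm

# Load the compiled kernels and fetch one by its generated name.
mod = tvm.runtime.load_module("libfeatgraph_kernels.so")
f = mod.get_function("SDDMMTreeReduction_int32_float32")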
@@ -4,7 +4,7 @@ from tvm import te
def sddmm_tree_reduction_gpu(idx_type, feat_type):
    """SDDMM kernels on GPU optimized with Tree Reduction.

    Parameters
    ----------

@@ -19,35 +19,40 @@ def sddmm_tree_reduction_gpu(idx_type, feat_type):
        The result IRModule.
    """
    # define vars and placeholders
    nnz = te.var("nnz", idx_type)
    num_rows = te.var("num_rows", idx_type)
    num_cols = te.var("num_cols", idx_type)
    H = te.var("num_heads", idx_type)
    D = te.var("feat_len", idx_type)
    row = te.placeholder((nnz,), idx_type, "row")
    col = te.placeholder((nnz,), idx_type, "col")
    ufeat = te.placeholder((num_rows, H, D), feat_type, "ufeat")
    vfeat = te.placeholder((num_cols, H, D), feat_type, "vfeat")

    # define edge computation function
    def edge_func(eid, h, i):
        k = te.reduce_axis((0, D), name="k")
        return te.sum(ufeat[row[eid], h, k] * vfeat[col[eid], h, k], axis=k)

    out = te.compute(
        (nnz, H, tvm.tir.IntImm(idx_type, 1)), edge_func, name="out"
    )
    # define schedules
    sched = te.create_schedule(out.op)
    edge_axis, head_axis, _ = out.op.axis
    reduce_axis = out.op.reduce_axis[0]
    _, red_inner = sched[out].split(reduce_axis, factor=32)
    edge_outer, edge_inner = sched[out].split(edge_axis, factor=32)
    sched[out].bind(red_inner, te.thread_axis("threadIdx.x"))
    sched[out].bind(edge_inner, te.thread_axis("threadIdx.y"))
    sched[out].bind(edge_outer, te.thread_axis("blockIdx.x"))
    sched[out].bind(head_axis, te.thread_axis("blockIdx.y"))
    return tvm.lower(
        sched,
        [row, col, ufeat, vfeat, out],
        name="SDDMMTreeReduction_{}_{}".format(idx_type, feat_type),
    )

if __name__ == "__main__":
    kernel0 = sddmm_tree_reduction_gpu("int32", "float32")
    print(kernel0)
import dgl
import dgl.backend as F
import torch

g = dgl.rand_graph(10, 15).int().to(torch.device(0))
gidx = g._graph
u = torch.rand((10, 2, 8), device=torch.device(0))
v = torch.rand((10, 2, 8), device=torch.device(0))
e = dgl.ops.gsddmm(g, "dot", u, v)
print(e)
e = torch.zeros((15, 2, 1), device=torch.device(0))
u = F.zerocopy_to_dgl_ndarray(u)
v = F.zerocopy_to_dgl_ndarray(v)
e = F.zerocopy_to_dgl_ndarray_for_write(e)
...
@@ -22,13 +22,13 @@ networks with PyTorch.
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Overview of Node Classification with GNN
...
@@ -31,11 +31,11 @@ By the end of this tutorial you will be able to:
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch

g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)

# Equivalently, PyTorch LongTensors also work.
...
@@ -19,13 +19,13 @@ GNN for node classification <1_introduction>`.
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Message passing and GNNs
...
@@ -19,17 +19,17 @@ By the end of this tutorial you will be able to
import itertools
import os

os.environ["DGLBACKEND"] = "pytorch"

import dgl
import dgl.data
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Overview of Link Prediction with GNN
# ------------------------------------
...
@@ -14,13 +14,13 @@ By the end of this tutorial, you will be able to
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################################################
# Overview of Graph Classification with GNN

@@ -54,6 +54,8 @@ print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)

from dgl.dataloading import GraphDataLoader

######################################################################
# Defining Data Loader
# --------------------

@@ -74,8 +76,6 @@ print("Number of graph categories:", dataset.gclasses)
from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)
...
@@ -88,10 +88,10 @@ interactions.head()
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch
from dgl.data import DGLDataset
...
@@ -26,10 +26,11 @@ Sampling for GNN Training <L0_neighbor_sampling_overview>`.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")

@@ -284,13 +285,14 @@ valid_dataloader = dgl.dataloading.DataLoader(
)

import sklearn.metrics

######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
...
@@ -53,10 +53,11 @@ Sampling for Node Classification <L1_large_node_classification>`.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")

@@ -339,6 +340,8 @@ predictor = DotPredictor().to(device)
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))

import sklearn.metrics

######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the

@@ -346,7 +349,6 @@ opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))
#

import tqdm

best_accuracy = 0
best_model_path = "model.pt"
...
@@ -14,30 +14,33 @@ for stochastic GNN training. It assumes that
"""

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import torch
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata["label"] = node_labels[:, 0]

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
node_features = graph.ndata["feat"]

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    graph,
    train_nids,
    sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)

input_nodes, output_nodes, mfgs = next(iter(train_dataloader))

@@ -75,8 +78,8 @@ print(mfg.num_src_nodes(), mfg.num_dst_nodes())
# will do with ``ndata`` on the graphs you have seen earlier:
#

mfg.srcdata["x"] = torch.zeros(mfg.num_src_nodes(), mfg.num_dst_nodes())
dst_feat = mfg.dstdata["feat"]

######################################################################

@@ -105,7 +108,11 @@ mfg.srcdata[dgl.NID], mfg.dstdata[dgl.NID]
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#

print(
    torch.equal(
        mfg.srcdata[dgl.NID][: mfg.num_dst_nodes()], mfg.dstdata[dgl.NID]
    )
)

######################################################################

@@ -113,7 +120,7 @@
# :math:`h_u^{(l-1)}`:
#

mfg.srcdata["h"] = torch.randn(mfg.num_src_nodes(), 10)

######################################################################

@@ -132,8 +139,8 @@ mfg.srcdata['h'] = torch.randn(mfg.num_src_nodes(), 10)
import dgl.function as fn

mfg.update_all(message_func=fn.copy_u("h", "m"), reduce_func=fn.mean("m", "h"))
m_v = mfg.dstdata["h"]
m_v

@@ -147,6 +154,7 @@ import torch.nn as nn
import torch.nn.functional as F
import tqdm

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

@@ -157,6 +165,7 @@ class SAGEConv(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.

@@ -174,14 +183,15 @@ class SAGEConv(nn.Module):
        """
        with g.local_scope():
            h_src, h_dst = h
            g.srcdata["h"] = h_src  # <---
            g.dstdata["h"] = h_dst  # <---
            # update_all is a message passing API.
            g.update_all(fn.copy_u("h", "m"), fn.mean("m", "h_N"))
            h_N = g.dstdata["h_N"]
            h_total = torch.cat([h_dst, h_N], dim=1)  # <---
            return self.linear(h_total)

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()

@@ -189,28 +199,31 @@ class Model(nn.Module):
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    graph,
    train_nids,
    sampler,
    device=device,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0,
)
model = Model(graph.ndata["feat"].shape[1], 128, dataset.num_classes).to(device)

with tqdm.tqdm(train_dataloader) as tq:
    for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
        inputs = mfgs[0].srcdata["feat"]
        labels = mfgs[-1].dstdata["label"]
        predictions = model(mfgs, inputs)

@@ -232,6 +245,7 @@ with tqdm.tqdm(train_dataloader) as tq:
# Say you start with a GNN module that works for full-graph training only:
#

class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

@@ -242,6 +256,7 @@ class SAGEConv(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.

@@ -258,10 +273,13 @@ class SAGEConv(nn.Module):
            The input node feature.
        """
        with g.local_scope():
            g.ndata["h"] = h
            # update_all is a message passing API.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

@@ -352,6 +370,7 @@ class SAGEConv(nn.Module):
# to something like the following:
#

class SAGEConvForBoth(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

@@ -362,6 +381,7 @@ class SAGEConvForBoth(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.

@@ -383,10 +403,13 @@ class SAGEConvForBoth(nn.Module):
        else:
            h_src = h_dst = h
        g.srcdata["h"] = h_src
        # update_all is a message passing API.
        g.update_all(
            message_func=fn.copy_u("h", "m"),
            reduce_func=fn.mean("m", "h_N"),
        )
        h_N = g.ndata["h_N"]
        h_total = torch.cat([h_dst, h_N], dim=1)
        return self.linear(h_total)
...
@@ -46,13 +46,13 @@ message passing APIs.
# representations :math:`h_v`, we can simply use builtin functions:

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_u(u="h", out="m")

@@ -156,7 +156,6 @@ dur = []
for epoch in range(50):
    if epoch >= 3:
        t0 = time.time()
    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)

@@ -168,14 +167,12 @@ for epoch in range(50):
    if epoch >= 3:
        dur.append(time.time() - t0)
    acc = evaluate(net, g, features, labels, test_mask)
    print(
        "Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur)
        )
    )

###############################################################################
# .. _math:
#
...
@@ -137,18 +137,29 @@ multiple edges among any given pair.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
from functools import partial

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

class RGCNLayer(nn.Module):
    def __init__(
        self,
        in_feat,
        out_feat,
        num_rels,
        num_bases=-1,
        bias=None,
        activation=None,
        is_input_layer=False,
    ):
        super(RGCNLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat

@@ -161,69 +172,85 @@ class RGCNLayer(nn.Module):
        # sanity check
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels
        # weight bases in equation (3)
        self.weight = nn.Parameter(
            torch.Tensor(self.num_bases, self.in_feat, self.out_feat)
        )
        if self.num_bases < self.num_rels:
            # linear combination coefficients in equation (3)
            self.w_comp = nn.Parameter(
                torch.Tensor(self.num_rels, self.num_bases)
            )
        # add bias
        if self.bias:
            self.bias = nn.Parameter(torch.Tensor(out_feat))
        # init trainable parameters
        nn.init.xavier_uniform_(
            self.weight, gain=nn.init.calculate_gain("relu")
        )
        if self.num_bases < self.num_rels:
            nn.init.xavier_uniform_(
                self.w_comp, gain=nn.init.calculate_gain("relu")
            )
        if self.bias:
            nn.init.xavier_uniform_(
                self.bias, gain=nn.init.calculate_gain("relu")
            )

    def forward(self, g):
        if self.num_bases < self.num_rels:
            # generate all weights from bases (equation (3))
            weight = self.weight.view(
                self.in_feat, self.num_bases, self.out_feat
            )
            weight = torch.matmul(self.w_comp, weight).view(
                self.num_rels, self.in_feat, self.out_feat
            )
        else:
            weight = self.weight
        if self.is_input_layer:

            def message_func(edges):
                # for input layer, matrix multiply can be converted to be
                # an embedding lookup using source node id
                embed = weight.view(-1, self.out_feat)
                index = edges.data[dgl.ETYPE] * self.in_feat + edges.src["id"]
                return {"msg": embed[index] * edges.data["norm"]}

        else:

            def message_func(edges):
                w = weight[edges.data[dgl.ETYPE]]
                msg = torch.bmm(edges.src["h"].unsqueeze(1), w).squeeze()
                msg = msg * edges.data["norm"]
                return {"msg": msg}

        def apply_func(nodes):
            h = nodes.data["h"]
            if self.bias:
                h = h + self.bias
            if self.activation:
                h = self.activation(h)
            return {"h": h}

        g.update_all(message_func, fn.sum(msg="msg", out="h"), apply_func)

###############################################################################
# Full R-GCN model defined
# ~~~~~~~~~~~~~~~~~~~~~~~

class Model(nn.Module):
    def __init__(
        self,
        num_nodes,
        h_dim,
        out_dim,
        num_rels,
        num_bases=-1,
        num_hidden_layers=1,
    ):
        super(Model, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim

@@ -257,23 +284,40 @@ class Model(nn.Module):
        return features

    def build_input_layer(self):
        return RGCNLayer(
            self.num_nodes,
            self.h_dim,
            self.num_rels,
            self.num_bases,
            activation=F.relu,
            is_input_layer=True,
        )

    def build_hidden_layer(self):
        return RGCNLayer(
            self.h_dim,
            self.h_dim,
            self.num_rels,
            self.num_bases,
            activation=F.relu,
        )

    def build_output_layer(self):
        return RGCNLayer(
            self.h_dim,
            self.out_dim,
            self.num_rels,
            self.num_bases,
            activation=partial(F.softmax, dim=1),
        )

    def forward(self, g):
        if self.features is not None:
            g.ndata["id"] = self.features
        for layer in self.layers:
            layer(g)
        return g.ndata.pop("h")

###############################################################################
# Handle dataset

@@ -284,16 +328,16 @@ class Model(nn.Module):
dataset = dgl.data.rdf.AIFBDataset()
g = dataset[0]
category = dataset.predict_category
train_mask = g.nodes[category].data.pop("train_mask")
test_mask = g.nodes[category].data.pop("test_mask")
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
labels = g.nodes[category].data.pop("label")
num_rels = len(g.canonical_etypes)
num_classes = dataset.num_classes
# normalization factor
for cetype in g.canonical_etypes:
    g.edges[cetype].data["norm"] = dgl.norm_by_dst(g, cetype).unsqueeze(1)
category_id = g.ntypes.index(category)

###############################################################################

@@ -309,17 +353,19 @@ lr = 0.01 # learning rate
l2norm = 0  # L2 norm coefficient

# create graph
g = dgl.to_homogeneous(g, edata=["norm"])
node_ids = torch.arange(g.num_nodes())
target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]

# create model
model = Model(
    g.num_nodes(),
    n_hidden,
    num_classes,
    num_rels,
    num_bases=n_bases,
    num_hidden_layers=n_hidden_layers,
)

###############################################################################
# Training loop

@@ -344,12 +390,15 @@ for epoch in range(n_epochs):
    val_loss = F.cross_entropy(logits[test_idx], labels[test_idx])
    val_acc = torch.sum(logits[test_idx].argmax(dim=1) == labels[test_idx])
    val_acc = val_acc.item() / len(test_idx)
    print(
        "Epoch {:05d} | ".format(epoch)
        + "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
            train_acc, loss.item()
        )
        + "Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
            val_acc, val_loss.item()
        )
    )

###############################################################################
# .. _link-prediction:
#
...
...@@ -87,19 +87,19 @@ Line Graph Neural Network ...@@ -87,19 +87,19 @@ Line Graph Neural Network
# than inter-class. # than inter-class.
import os import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import torch import torch
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl
from dgl.data import citation_graph as citegrh from dgl.data import citation_graph as citegrh
data = citegrh.load_cora() data = citegrh.load_cora()
G = data[0] G = data[0]
labels = th.tensor(G.ndata['label']) labels = th.tensor(G.ndata["label"])
# find all the nodes labeled with class 0 # find all the nodes labeled with class 0
label0_nodes = th.nonzero(labels == 0, as_tuple=False).squeeze() label0_nodes = th.nonzero(labels == 0, as_tuple=False).squeeze()
...@@ -108,7 +108,9 @@ src, _ = G.in_edges(label0_nodes) ...@@ -108,7 +108,9 @@ src, _ = G.in_edges(label0_nodes)
src_labels = labels[src] src_labels = labels[src]
# find all the edges whose both endpoints are in class 0 # find all the edges whose both endpoints are in class 0
intra_src = th.nonzero(src_labels == 0, as_tuple=False) intra_src = th.nonzero(src_labels == 0, as_tuple=False)
print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels))) print("Intra-class edges percent: %.4f" % (len(intra_src) / len(src_labels)))
import matplotlib.pyplot as plt
########################################################################################### ###########################################################################################
# Binary community subgraph from Cora with a test dataset # Binary community subgraph from Cora with a test dataset
...@@ -127,19 +129,30 @@ print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels))) ...@@ -127,19 +129,30 @@ print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels)))
# With the following code, you can visualize one of the training samples and its community structure. # With the following code, you can visualize one of the training samples and its community structure.
import networkx as nx import networkx as nx
import matplotlib.pyplot as plt
train_set = dgl.data.CoraBinary() train_set = dgl.data.CoraBinary()
G1, pmpd1, label1 = train_set[1] G1, pmpd1, label1 = train_set[1]
nx_G1 = G1.to_networkx() nx_G1 = G1.to_networkx()
def visualize(labels, g): def visualize(labels, g):
pos = nx.spring_layout(g, seed=1) pos = nx.spring_layout(g, seed=1)
plt.figure(figsize=(8, 8)) plt.figure(figsize=(8, 8))
plt.axis('off') plt.axis("off")
nx.draw_networkx(g, pos=pos, node_size=50, cmap=plt.get_cmap('coolwarm'), nx.draw_networkx(
node_color=labels, edge_color='k', g,
arrows=False, width=0.5, style='dotted', with_labels=False) pos=pos,
node_size=50,
cmap=plt.get_cmap("coolwarm"),
node_color=labels,
edge_color="k",
arrows=False,
width=0.5,
style="dotted",
with_labels=False,
)
visualize(label1, nx_G1) visualize(label1, nx_G1)
########################################################################################### ###########################################################################################
...@@ -369,20 +382,23 @@ visualize(label1, nx_G1) ...@@ -369,20 +382,23 @@ visualize(label1, nx_G1)
# Return a list containing features gathered from multiple radius. # Return a list containing features gathered from multiple radius.
import dgl.function as fn import dgl.function as fn
def aggregate_radius(radius, g, z): def aggregate_radius(radius, g, z):
# initializing list to collect message passing result # initializing list to collect message passing result
z_list = [] z_list = []
g.ndata['z'] = z g.ndata["z"] = z
# pulling message from 1-hop neighbourhood # pulling message from 1-hop neighbourhood
g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z')) g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
z_list.append(g.ndata['z']) z_list.append(g.ndata["z"])
for i in range(radius - 1): for i in range(radius - 1):
for j in range(2 ** i): for j in range(2**i):
#pulling message from 2^j neighborhood # pulling message from 2^j neighborhood
g.update_all(fn.copy_u(u='z', out='m'), fn.sum(msg='m', out='z')) g.update_all(fn.copy_u(u="z", out="m"), fn.sum(msg="m", out="z"))
z_list.append(g.ndata['z']) z_list.append(g.ndata["z"])
return z_list return z_list
######################################################################### #########################################################################
# Implementing :math:`\text{fuse}` as sparse matrix multiplication # Implementing :math:`\text{fuse}` as sparse matrix multiplication
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -428,7 +444,8 @@ class LGNNCore(nn.Module): ...@@ -428,7 +444,8 @@ class LGNNCore(nn.Module):
self.linear_prev = nn.Linear(in_feats, out_feats) self.linear_prev = nn.Linear(in_feats, out_feats)
self.linear_deg = nn.Linear(in_feats, out_feats) self.linear_deg = nn.Linear(in_feats, out_feats)
self.linear_radius = nn.ModuleList( self.linear_radius = nn.ModuleList(
[nn.Linear(in_feats, out_feats) for i in range(radius)]) [nn.Linear(in_feats, out_feats) for i in range(radius)]
)
self.linear_fuse = nn.Linear(in_feats, out_feats) self.linear_fuse = nn.Linear(in_feats, out_feats)
self.bn = nn.BatchNorm1d(out_feats) self.bn = nn.BatchNorm1d(out_feats)
...@@ -442,7 +459,9 @@ class LGNNCore(nn.Module): ...@@ -442,7 +459,9 @@ class LGNNCore(nn.Module):
# aggregate 2^j-hop features # aggregate 2^j-hop features
hop2j_list = aggregate_radius(self.radius, g, feat_a) hop2j_list = aggregate_radius(self.radius, g, feat_a)
# apply linear transformation # apply linear transformation
hop2j_list = [linear(x) for linear, x in zip(self.linear_radius, hop2j_list)] hop2j_list = [
linear(x) for linear, x in zip(self.linear_radius, hop2j_list)
]
radius_proj = sum(hop2j_list) radius_proj = sum(hop2j_list)
# term "fuse" # term "fuse"
...@@ -458,6 +477,7 @@ class LGNNCore(nn.Module): ...@@ -458,6 +477,7 @@ class LGNNCore(nn.Module):
return result return result
############################################################################################################## ##############################################################################################################
# Chain-up LGNN abstractions as an LGNN layer # Chain-up LGNN abstractions as an LGNN layer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -482,6 +502,7 @@ class LGNNLayer(nn.Module): ...@@ -482,6 +502,7 @@ class LGNNLayer(nn.Module):
        next_lg_x = self.lg_layer(lg, lg_x, x, deg_lg, pm_pd_y)
        return next_x, next_lg_x


########################################################################################
# Chain up LGNN layers
# ~~~~~~~~~~~~~~~~~~~~
@@ -504,15 +525,17 @@ class LGNN(nn.Module):
        x, lg_x = self.layer2(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        x, lg_x = self.layer3(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        return self.linear(x)


#########################################################################################
# Training and inference
# -----------------------
# First load the data.
from torch.utils.data import DataLoader

training_loader = DataLoader(
    train_set, batch_size=1, collate_fn=train_set.collate_fn, drop_last=True
)
#######################################################################################
# Next, define the main training loop. Note that each training sample contains
@@ -536,9 +559,12 @@ optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
def sparse2th(mat):
    value = mat.data
    indices = th.LongTensor([mat.row, mat.col])
    tensor = th.sparse.FloatTensor(
        indices, th.from_numpy(value).float(), mat.shape
    )
    return tensor
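###############################################################################
# A small usage sketch for ``sparse2th`` (added for illustration; the identity
# matrix below is a made-up stand-in for the tutorial's actual input):

import numpy as np
import scipy.sparse as sp

demo_coo = sp.coo_matrix(np.eye(3, dtype=np.float32))
print(sparse2th(demo_coo).to_dense())  # recovers the 3x3 identity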
# Train for 20 epochs
for i in range(20):
    all_loss = []
@@ -571,11 +597,11 @@ for i in range(20):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    niters = len(all_loss)
    print(
        "Epoch %d | loss %.4f | accuracy %.4f"
        % (i, sum(all_loss) / niters, sum(all_acc) / niters)
    )
#######################################################################################
# Visualize training progress
# -----------------------------
@@ -613,6 +639,7 @@ visualize(label1, nx_G1)
# :math:`\{Pm, Pd\}` as a block-diagonal matrix, matching the DGL batched
# graph API.


def collate_fn(batch):
    graphs, pmpds, labels = zip(*batch)
    batched_graphs = dgl.batch(graphs)
@@ -620,6 +647,7 @@ def collate_fn(batch):
    batched_labels = np.concatenate(labels, axis=0)
    return batched_graphs, batched_pmpds, batched_labels
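###############################################################################
# The construction of ``batched_pmpds`` is elided in this diff; as a rough
# sketch (an assumption, not the tutorial's exact code), SciPy can assemble
# the block-diagonal matrix directly from the per-sample matrices:

import scipy.sparse as sp

batched_pmpd_demo = sp.block_diag([sp.eye(2), sp.eye(3)])  # (5, 5) blocks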
######################################################################################
# You can find the complete code on Github at
# `Community Detection with Graph Neural Networks (CDGNN) <https://github.com/dmlc/dgl/tree/master/examples/pytorch/line_graph>`_.
@@ -105,9 +105,8 @@ structure-free normalization, in the style of attention.
# subpackage. Simply import ``GATConv`` as follows.
import os

os.environ["DGLBACKEND"] = "pytorch"
###############################################################
# Readers can skip the following step-by-step explanation of the implementation
# and jump to `Put everything together`_ for training and visualization results.
@@ -125,6 +124,7 @@ from dgl.nn.pytorch import GATConv
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv


class GATLayer(nn.Module):
@@ -139,37 +139,38 @@ class GATLayer(nn.Module):
    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
        a = self.attn_fc(z2)
        return {"e": F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {"z": edges.src["z"], "e": edges.data["e"]}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox["e"], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
        return {"h": h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata["z"] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop("h")
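###############################################################################
# A hedged usage sketch (the constructor is elided in this diff; the full
# tutorial defines it as ``GATLayer(g, in_dim, out_dim)``): one layer maps
# (N, in_dim) node features to (N, out_dim).

import dgl

toy_gat_g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # made-up 3-node graph
layer = GATLayer(toy_gat_g, in_dim=5, out_dim=4)  # signature assumed
print(layer(torch.randn(3, 5)).shape)  # torch.Size([3, 4])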
##################################################################
# Equation (1)
@@ -195,11 +196,13 @@ class GATLayer(nn.Module):
# ``apply_edges`` API. The argument to ``apply_edges`` is an **Edge UDF**,
# defined below:


def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src["z"], edges.dst["z"]], dim=1)
    a = self.attn_fc(z2)
    return {"e": F.leaky_relu(a)}
#########################################################################
# Here, the dot product with the learnable weight vector :math:`\vec{a^{(l)}}`
@@ -229,13 +232,15 @@ def edge_attention(self, edges):
# Both tasks first fetch data from the mailbox and then manipulate it on the
# second dimension (``dim=1``), on which the messages are batched.


def reduce_func(self, nodes):
    # reduce UDF for equation (3) & (4)
    # equation (3)
    alpha = F.softmax(nodes.mailbox["e"], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox["z"], dim=1)
    return {"h": h}
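###############################################################################
# Restated as equations (the standard GAT formulation that this UDF
# implements):
#
# .. math::
#
#     \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})},
#     \qquad
#     h_i = \sum_{j \in \mathcal{N}(i)} \alpha_{ij} z_j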
#####################################################################
# Multi-head attention
@@ -258,8 +263,9 @@ def reduce_func(self, nodes):
# Use the single-head ``GATLayer`` defined above as the building block
# for the ``MultiHeadGATLayer`` below:


class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge="cat"):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
@@ -268,19 +274,21 @@ class MultiHeadGATLayer(nn.Module):
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == "cat":
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge by averaging over the head dimension (dim=0); without an
            # explicit dim, torch.mean would collapse the result to a scalar
            return torch.mean(torch.stack(head_outs), dim=0)
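###############################################################################
# Worked shape example (added for illustration): with ``num_heads=2`` and
# ``out_dim=8`` on ``N`` nodes, ``merge="cat"`` returns a tensor of shape
# (N, 16), while the averaging branch returns shape (N, 8).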
###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now, you can define a two-layer GAT model.
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
@@ -296,33 +304,34 @@ class GAT(nn.Module):
        h = self.layer2(h)
        return h


#############################################################################
# We then load the Cora dataset using DGL's built-in data module.

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx


def load_cora_data():
    data = citegrh.load_cora()
    g = data[0]
    mask = torch.BoolTensor(g.ndata["train_mask"])
    return g, g.ndata["feat"], g.ndata["label"], mask
##############################################################################
# The training loop is exactly the same as in the GCN tutorial.
import time

import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g, in_dim=features.size()[1], hidden_dim=8, out_dim=7, num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
@@ -344,8 +353,11 @@ for epoch in range(30):
    if epoch >= 3:
        dur.append(time.time() - t0)

    print(
        "Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), np.mean(dur)
        )
    )
#########################################################################
# Visualizing and understanding the learned attention
...