Unverified Commit 98325b10, authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4691)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent c24e285a
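
The hunks below are purely mechanical reformatting: quote normalization, trailing commas, and wrapping of long calls; the import reordering suggests an import sorter such as isort was run alongside Black, though the exact tooling and configuration are not shown in this diff. As a minimal sketch of the kind of rewrite involved (hypothetical illustration only, assuming the black package is installed; not part of this commit):

import black  # illustration only; not part of this commit

src = "x = dict(a='1', b='2')\n"
# Black normalizes string quotes and layout; Mode() defaults to an 88-column line length.
print(black.format_str(src, mode=black.Mode()))  # -> x = dict(a="1", b="2")
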
"""Views of DGLGraph.""" """Views of DGLGraph."""
from __future__ import absolute_import from __future__ import absolute_import
from collections import namedtuple, defaultdict from collections import defaultdict, namedtuple
from collections.abc import MutableMapping from collections.abc import MutableMapping
from .base import ALL, DGLError
from . import backend as F from . import backend as F
from .base import ALL, DGLError
from .frame import LazyFeature from .frame import LazyFeature
NodeSpace = namedtuple('NodeSpace', ['data']) NodeSpace = namedtuple("NodeSpace", ["data"])
EdgeSpace = namedtuple('EdgeSpace', ['data']) EdgeSpace = namedtuple("EdgeSpace", ["data"])
class HeteroNodeView(object): class HeteroNodeView(object):
"""A NodeView class to act as G.nodes for a DGLHeteroGraph.""" """A NodeView class to act as G.nodes for a DGLHeteroGraph."""
__slots__ = ['_graph', '_typeid_getter']
__slots__ = ["_graph", "_typeid_getter"]
def __init__(self, graph, typeid_getter): def __init__(self, graph, typeid_getter):
self._graph = graph self._graph = graph
...@@ -23,8 +24,9 @@ class HeteroNodeView(object): ...@@ -23,8 +24,9 @@ class HeteroNodeView(object):
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, slice): if isinstance(key, slice):
# slice # slice
if not (key.start is None and key.stop is None if not (
and key.step is None): key.start is None and key.stop is None and key.step is None
):
raise DGLError('Currently only full slice ":" is supported') raise DGLError('Currently only full slice ":" is supported')
nodes = ALL nodes = ALL
ntype = None ntype = None
...@@ -38,20 +40,25 @@ class HeteroNodeView(object): ...@@ -38,20 +40,25 @@ class HeteroNodeView(object):
ntype = None ntype = None
ntid = self._typeid_getter(ntype) ntid = self._typeid_getter(ntype)
return NodeSpace( return NodeSpace(
data=HeteroNodeDataView( data=HeteroNodeDataView(self._graph, ntype, ntid, nodes)
self._graph, ntype, ntid, nodes)) )
def __call__(self, ntype=None): def __call__(self, ntype=None):
"""Return the nodes.""" """Return the nodes."""
ntid = self._typeid_getter(ntype) ntid = self._typeid_getter(ntype)
ret = F.arange(0, self._graph._graph.number_of_nodes(ntid), ret = F.arange(
dtype=self._graph.idtype, ctx=self._graph.device) 0,
self._graph._graph.number_of_nodes(ntid),
dtype=self._graph.idtype,
ctx=self._graph.device,
)
return ret return ret
class HeteroNodeDataView(MutableMapping): class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called.""" """The data view class when G.ndata[ntype] is called."""
__slots__ = ['_graph', '_ntype', '_ntid', '_nodes']
__slots__ = ["_graph", "_ntype", "_ntid", "_nodes"]
def __init__(self, graph, ntype, ntid, nodes): def __init__(self, graph, ntype, ntid, nodes):
self._graph = graph self._graph = graph
...@@ -63,9 +70,9 @@ class HeteroNodeDataView(MutableMapping): ...@@ -63,9 +70,9 @@ class HeteroNodeDataView(MutableMapping):
if isinstance(self._ntype, list): if isinstance(self._ntype, list):
ret = {} ret = {}
for (i, ntype) in enumerate(self._ntype): for (i, ntype) in enumerate(self._ntype):
value = self._graph._get_n_repr( value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(
self._ntid[i], self._nodes).get( key, None
key, None) )
if value is not None: if value is not None:
ret[ntype] = value ret[ntype] = value
return ret return ret
...@@ -76,17 +83,19 @@ class HeteroNodeDataView(MutableMapping): ...@@ -76,17 +83,19 @@ class HeteroNodeDataView(MutableMapping):
if isinstance(val, LazyFeature): if isinstance(val, LazyFeature):
self._graph._node_frames[self._ntid][key] = val self._graph._node_frames[self._ntid][key] = val
elif isinstance(self._ntype, list): elif isinstance(self._ntype, list):
assert isinstance(val, dict), \ assert isinstance(val, dict), (
'Current HeteroNodeDataView has multiple node types, ' \ "Current HeteroNodeDataView has multiple node types, "
'please passing the node type and the corresponding data through a dict.' "please passing the node type and the corresponding data through a dict."
)
for (ntype, data) in val.items(): for (ntype, data) in val.items():
ntid = self._graph.get_ntype_id(ntype) ntid = self._graph.get_ntype_id(ntype)
self._graph._set_n_repr(ntid, self._nodes, {key: data}) self._graph._set_n_repr(ntid, self._nodes, {key: data})
else: else:
assert isinstance(val, dict) is False, \ assert isinstance(val, dict) is False, (
'The HeteroNodeDataView has only one node type. ' \ "The HeteroNodeDataView has only one node type. "
'please pass a tensor directly' "please pass a tensor directly"
)
self._graph._set_n_repr(self._ntid, self._nodes, {key: val}) self._graph._set_n_repr(self._ntid, self._nodes, {key: val})
def __delitem__(self, key): def __delitem__(self, key):
...@@ -108,8 +117,10 @@ class HeteroNodeDataView(MutableMapping): ...@@ -108,8 +117,10 @@ class HeteroNodeDataView(MutableMapping):
else: else:
ret = self._graph._get_n_repr(self._ntid, self._nodes) ret = self._graph._get_n_repr(self._ntid, self._nodes)
if as_dict: if as_dict:
ret = {key: ret[key] ret = {
for key in self._graph._node_frames[self._ntid]} key: ret[key]
for key in self._graph._node_frames[self._ntid]
}
return ret return ret
def __len__(self): def __len__(self):
...@@ -130,7 +141,8 @@ class HeteroNodeDataView(MutableMapping): ...@@ -130,7 +141,8 @@ class HeteroNodeDataView(MutableMapping):
class HeteroEdgeView(object): class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph.""" """A EdgeView class to act as G.edges for a DGLHeteroGraph."""
__slots__ = ['_graph']
__slots__ = ["_graph"]
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
...@@ -138,8 +150,9 @@ class HeteroEdgeView(object): ...@@ -138,8 +150,9 @@ class HeteroEdgeView(object):
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(key, slice): if isinstance(key, slice):
# slice # slice
if not (key.start is None and key.stop is None if not (
and key.step is None): key.start is None and key.stop is None and key.step is None
):
raise DGLError('Currently only full slice ":" is supported') raise DGLError('Currently only full slice ":" is supported')
edges = ALL edges = ALL
etype = None etype = None
...@@ -168,23 +181,26 @@ class HeteroEdgeView(object): ...@@ -168,23 +181,26 @@ class HeteroEdgeView(object):
class HeteroEdgeDataView(MutableMapping): class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.edata[etype] is called.""" """The data view class when G.edata[etype] is called."""
__slots__ = ['_graph', '_etype', '_etid', '_edges']
__slots__ = ["_graph", "_etype", "_etid", "_edges"]
def __init__(self, graph, etype, edges): def __init__(self, graph, etype, edges):
self._graph = graph self._graph = graph
self._etype = etype self._etype = etype
self._etid = [self._graph.get_etype_id(t) for t in etype] \ self._etid = (
if isinstance(etype, list) \ [self._graph.get_etype_id(t) for t in etype]
if isinstance(etype, list)
else self._graph.get_etype_id(etype) else self._graph.get_etype_id(etype)
)
self._edges = edges self._edges = edges
def __getitem__(self, key): def __getitem__(self, key):
if isinstance(self._etype, list): if isinstance(self._etype, list):
ret = {} ret = {}
for (i, etype) in enumerate(self._etype): for (i, etype) in enumerate(self._etype):
value = self._graph._get_e_repr( value = self._graph._get_e_repr(self._etid[i], self._edges).get(
self._etid[i], self._edges).get( key, None
key, None) )
if value is not None: if value is not None:
ret[etype] = value ret[etype] = value
return ret return ret
...@@ -195,17 +211,19 @@ class HeteroEdgeDataView(MutableMapping): ...@@ -195,17 +211,19 @@ class HeteroEdgeDataView(MutableMapping):
if isinstance(val, LazyFeature): if isinstance(val, LazyFeature):
self._graph._edge_frames[self._etid][key] = val self._graph._edge_frames[self._etid][key] = val
elif isinstance(self._etype, list): elif isinstance(self._etype, list):
assert isinstance(val, dict), \ assert isinstance(val, dict), (
'Current HeteroEdgeDataView has multiple edge types, ' \ "Current HeteroEdgeDataView has multiple edge types, "
'please pass the edge type and the corresponding data through a dict.' "please pass the edge type and the corresponding data through a dict."
)
for (etype, data) in val.items(): for (etype, data) in val.items():
etid = self._graph.get_etype_id(etype) etid = self._graph.get_etype_id(etype)
self._graph._set_e_repr(etid, self._edges, {key: data}) self._graph._set_e_repr(etid, self._edges, {key: data})
else: else:
assert isinstance(val, dict) is False, \ assert isinstance(val, dict) is False, (
'The HeteroEdgeDataView has only one edge type. ' \ "The HeteroEdgeDataView has only one edge type. "
'please pass a tensor directly' "please pass a tensor directly"
)
self._graph._set_e_repr(self._etid, self._edges, {key: val}) self._graph._set_e_repr(self._etid, self._edges, {key: val})
def __delitem__(self, key): def __delitem__(self, key):
...@@ -227,8 +245,10 @@ class HeteroEdgeDataView(MutableMapping): ...@@ -227,8 +245,10 @@ class HeteroEdgeDataView(MutableMapping):
else: else:
ret = self._graph._get_e_repr(self._etid, self._edges) ret = self._graph._get_e_repr(self._etid, self._edges)
if as_dict: if as_dict:
ret = {key: ret[key] ret = {
for key in self._graph._edge_frames[self._etid]} key: ret[key]
for key in self._graph._edge_frames[self._etid]
}
return ret return ret
def __len__(self): def __len__(self):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import glob
import os
import platform
import shutil
import sys
import sysconfig

from setuptools import find_packages
from setuptools.dist import Distribution

# need to use distutils.core for correct placement of cython dll
if "--inplace" in sys.argv:
    from distutils.core import setup
    from distutils.extension import Extension
else:
@@ -31,34 +31,35 @@ def get_lib_path():
    """Get library path, name and version"""
    # We can not import `libinfo.py` in setup.py directly since __init__.py
    # Will be invoked which introduces dependences
    libinfo_py = os.path.join(CURRENT_DIR, "./dgl/_ffi/libinfo.py")
    libinfo = {"__file__": libinfo_py}
    exec(
        compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"),
        libinfo,
        libinfo,
    )
    version = libinfo["__version__"]
    lib_path = libinfo["find_lib_path"]()
    libs = [lib_path[0]]
    return libs, version


def get_ta_lib_pattern():
    if sys.platform.startswith("linux"):
        ta_lib_pattern = "libtensoradapter_*.so"
    elif sys.platform.startswith("darwin"):
        ta_lib_pattern = "libtensoradapter_*.dylib"
    elif sys.platform.startswith("win"):
        ta_lib_pattern = "tensoradapter_*.dll"
    else:
        raise NotImplementedError("Unsupported system: %s" % sys.platform)
    return ta_lib_pattern


LIBS, VERSION = get_lib_path()
BACKENDS = ["pytorch"]
TA_LIB_PATTERN = get_ta_lib_pattern()
@@ -78,11 +79,9 @@ def cleanup():
    for backend in BACKENDS:
        for ta_path in glob.glob(
            os.path.join(
                CURRENT_DIR, "dgl", "tensoradapter", backend, TA_LIB_PATTERN
            )
        ):
            try:
                os.remove(ta_path)
            except BaseException:
@@ -91,17 +90,21 @@ def cleanup():
def config_cython():
    """Try to configure cython and return cython configuration"""
    if sys.platform.startswith("win"):
        print(
            "WARNING: Cython is not supported on Windows, will compile without cython module"
        )
        return []
    sys_cflags = sysconfig.get_config_var("CFLAGS")
    if "i386" in sys_cflags and "x86_64" in sys_cflags:
        print(
            "WARNING: Cython library may not be compiled correctly with both i386 and x64"
        )
        return []
    try:
        from Cython.Build import cythonize

        # from setuptools.extension import Extension
        if sys.version_info >= (3, 0):
            subdir = "_cy3"
@@ -109,32 +112,38 @@ def config_cython():
            subdir = "_cy2"
        ret = []
        path = "dgl/_ffi/_cython"
        library_dirs = ["dgl", "../build/Release", "../build"]
        libraries = ["dgl"]
        for fn in os.listdir(path):
            if not fn.endswith(".pyx"):
                continue
            ret.append(
                Extension(
                    "dgl._ffi.%s.%s" % (subdir, fn[:-4]),
                    ["dgl/_ffi/_cython/%s" % fn],
                    include_dirs=[
                        "../include/",
                        "../third_party/dmlc-core/include",
                        "../third_party/dlpack/include",
                    ],
                    library_dirs=library_dirs,
                    libraries=libraries,
                    # Crashes without this flag with GCC 5.3.1
                    extra_compile_args=["-std=c++11"],
                    language="c++",
                )
            )
        return cythonize(ret, force=True)
    except ImportError:
        print(
            "WARNING: Cython is not installed, will compile without cython module"
        )
        return []


include_libs = False
wheel_include_libs = False
if "bdist_wheel" in sys.argv or os.getenv("CONDA_BUILD"):
    wheel_include_libs = True
elif "clean" in sys.argv:
    cleanup()
@@ -147,78 +156,76 @@ setup_kwargs = {}
if wheel_include_libs:
    with open("MANIFEST.in", "w") as fo:
        for path in LIBS:
            shutil.copy(path, os.path.join(CURRENT_DIR, "dgl"))
            dir_, libname = os.path.split(path)
            fo.write("include dgl/%s\n" % libname)

            for backend in BACKENDS:
                for ta_path in glob.glob(
                    os.path.join(dir_, "tensoradapter", backend, TA_LIB_PATTERN)
                ):
                    ta_name = os.path.basename(ta_path)
                    os.makedirs(
                        os.path.join(CURRENT_DIR, "dgl", "tensoradapter", backend),
                        exist_ok=True,
                    )
                    shutil.copy(
                        os.path.join(dir_, "tensoradapter", backend, ta_name),
                        os.path.join(CURRENT_DIR, "dgl", "tensoradapter", backend),
                    )
                    fo.write(
                        "include dgl/tensoradapter/%s/%s\n" % (backend, ta_name)
                    )
    setup_kwargs = {"include_package_data": True}

# For source tree setup
# Conda build also includes the binary library
if include_libs:
    rpath = [os.path.relpath(path, CURRENT_DIR) for path in LIBS]
    data_files = [("dgl", rpath)]
    for path in LIBS:
        for backend in BACKENDS:
            data_files.append(
                (
                    "dgl/tensoradapter/%s" % backend,
                    glob.glob(
                        os.path.join(
                            os.path.dirname(os.path.relpath(path, CURRENT_DIR)),
                            "tensoradapter",
                            backend,
                            TA_LIB_PATTERN,
                        )
                    ),
                )
            )
    setup_kwargs = {"include_package_data": True, "data_files": data_files}

setup(
    name="dgl" + os.getenv("DGL_PACKAGE_SUFFIX", ""),
    version=VERSION,
    description="Deep Graph Library",
    zip_safe=False,
    maintainer="DGL Team",
    maintainer_email="wmjlyjemaine@gmail.com",
    packages=find_packages(),
    install_requires=[
        "numpy>=1.14.0",
        "scipy>=1.1.0",
        "networkx>=2.1",
        "requests>=2.19.0",
        "tqdm",
        "psutil>=5.8.0",
    ],
    url="https://github.com/dmlc/dgl",
    distclass=BinaryDistribution,
    ext_modules=config_cython(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
    ],
    license="APACHE",
    **setup_kwargs
)
...
@@ -8,10 +8,11 @@ List of affected files:
"""
import os
import re

# current version
# We use the version of the incoming release for code
# that is under development
__version__ = "0.10" + os.getenv("DGL_PRERELEASE", "")
print(__version__)

# Implementations
@@ -47,22 +48,24 @@ def main():
    curr_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    proj_root = os.path.abspath(os.path.join(curr_dir, ".."))
    # python path
    update(
        os.path.join(proj_root, "python", "dgl", "_ffi", "libinfo.py"),
        r"(?<=__version__ = \")[.0-9a-z]+",
        __version__,
    )
    # C++ header
    update(
        os.path.join(proj_root, "include", "dgl", "runtime", "c_runtime_api.h"),
        '(?<=DGL_VERSION ")[.0-9a-z]+',
        __version__,
    )
    # conda
    for path in ["dgl"]:
        update(
            os.path.join(proj_root, "conda", path, "meta.yaml"),
            '(?<=version: ")[.0-9a-z]+',
            __version__,
        )


if __name__ == "__main__":
...
import os

import torch

cmake_prefix_path = getattr(
    torch.utils,
    "cmake_prefix_path",
    os.path.join(os.path.dirname(torch.__file__), "share", "cmake"),
)
version = torch.__version__.split("+")[0]
print(";".join([cmake_prefix_path, version]))
import importlib
import os
import sys

import numpy as np

from dgl.backend import *
from dgl.nn import *

from . import backend_unittest

mod = importlib.import_module(".%s" % backend_name, __name__)
thismod = sys.modules[__name__]

for api in backend_unittest.__dict__.keys():
    if api.startswith("__"):
        continue
    elif callable(mod.__dict__[api]):
        # Tensor APIs used in unit tests MUST be supported across all backends
@@ -26,39 +29,51 @@ _arange = arange
_full = full
_full_1d = full_1d
_softmax = softmax

_default_context_str = os.getenv("DGLTESTDEV", "cpu")
_context_dict = {
    "cpu": cpu(),
    "gpu": cuda(),
}
_default_context = _context_dict[_default_context_str]


def ctx():
    return _default_context


def gpu_ctx():
    return _default_context_str == "gpu"


def zeros(shape, dtype=float32, ctx=_default_context):
    return _zeros(shape, dtype, ctx)


def ones(shape, dtype=float32, ctx=_default_context):
    return _ones(shape, dtype, ctx)


def randn(shape):
    return copy_to(_randn(shape), _default_context)


def tensor(data, dtype=None):
    return copy_to(_tensor(data, dtype), _default_context)


def arange(start, stop, dtype=int64, ctx=None):
    return _arange(
        start, stop, dtype, ctx if ctx is not None else _default_context
    )


def full(shape, fill_value, dtype, ctx=_default_context):
    return _full(shape, fill_value, dtype, ctx)


def full_1d(length, fill_value, dtype, ctx=_default_context):
    return _full_1d(length, fill_value, dtype, ctx)


def softmax(x, dim):
    return _softmax(x, dim)
@@ -5,102 +5,127 @@ unit testing, other than the ones used in the framework itself.
###############################################################################
# Tensor, data type and context interfaces


def cuda():
    """Context object for CUDA."""
    pass


def is_cuda_available():
    """Check whether CUDA is available."""
    pass


###############################################################################
# Tensor functions on feature data
# --------------------------------
# These functions are performance critical, so it's better to have efficient
# implementation in each framework.


def array_equal(a, b):
    """Check whether the two tensors are *exactly* equal."""
    pass


def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Check whether the two tensors are numerically close to each other."""
    pass


def randn(shape):
    """Generate a tensor with elements from standard normal distribution."""
    pass


def full(shape, fill_value, dtype, ctx):
    pass


def narrow_row_set(x, start, stop, new):
    """Set a slice of the given tensor to a new value."""
    pass


def sparse_to_numpy(x):
    """Convert a sparse tensor to a numpy array."""
    pass


def clone(x):
    pass


def reduce_sum(x):
    """Sums all the elements into a single scalar."""
    pass


def softmax(x, dim):
    """Softmax Operation on Tensors"""
    pass


def spmm(x, y):
    """Sparse dense matrix multiply"""
    pass


def add(a, b):
    """Compute a + b"""
    pass


def sub(a, b):
    """Compute a - b"""
    pass


def mul(a, b):
    """Compute a * b"""
    pass


def div(a, b):
    """Compute a / b"""
    pass


def sum(x, dim, keepdims=False):
    """Computes the sum of array elements over given axes"""
    pass


def max(x, dim):
    """Computes the max of array elements over given axes"""
    pass


def min(x, dim):
    """Computes the min of array elements over given axes"""
    pass


def prod(x, dim):
    """Computes the prod of array elements over given axes"""
    pass


def matmul(a, b):
    """Compute Matrix Multiplication between a and b"""
    pass


def dot(a, b):
    """Compute Dot between a and b"""
    pass


def abs(a):
    """Compute the absolute value of a"""
    pass


###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
...
from __future__ import absolute_import

import mxnet as mx
import mxnet.ndarray as nd
import numpy as np


def cuda():
    return mx.gpu()


def is_cuda_available():
    # TODO: Does MXNet have a convenient function to test GPU availability/compilation?
    try:
@@ -15,65 +17,86 @@ def is_cuda_available():
    except mx.MXNetError:
        return False


def array_equal(a, b):
    return nd.equal(a, b).asnumpy().all()


def allclose(a, b, rtol=1e-4, atol=1e-4):
    return np.allclose(a.asnumpy(), b.asnumpy(), rtol=rtol, atol=atol)


def randn(shape):
    return nd.random.randn(*shape)


def full(shape, fill_value, dtype, ctx):
    return nd.full(shape, fill_value, dtype=dtype, ctx=ctx)


def narrow_row_set(x, start, stop, new):
    x[start:stop] = new


def sparse_to_numpy(x):
    return x.asscipy().todense().A


def clone(x):
    return x.copy()


def reduce_sum(x):
    return x.sum()


def softmax(x, dim):
    return nd.softmax(x, axis=dim)


def spmm(x, y):
    return nd.dot(x, y)


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


def mul(a, b):
    return a * b


def div(a, b):
    return a / b


def sum(x, dim, keepdims=False):
    return x.sum(dim, keepdims=keepdims)


def max(x, dim):
    return x.max(dim)


def min(x, dim):
    return x.min(dim)


def prod(x, dim):
    return x.prod(dim)


def matmul(a, b):
    return nd.dot(a, b)


def dot(a, b):
    return nd.sum(mul(a, b), axis=-1)


def abs(a):
    return nd.abs(a)
@@ -2,72 +2,94 @@ from __future__ import absolute_import

import torch as th


def cuda():
    return th.device("cuda:0")


def is_cuda_available():
    return th.cuda.is_available()


def array_equal(a, b):
    return th.equal(a.cpu(), b.cpu())


def allclose(a, b, rtol=1e-4, atol=1e-4):
    return th.allclose(a.float().cpu(), b.float().cpu(), rtol=rtol, atol=atol)


def randn(shape):
    return th.randn(*shape)


def full(shape, fill_value, dtype, ctx):
    return th.full(shape, fill_value, dtype=dtype, device=ctx)


def narrow_row_set(x, start, stop, new):
    x[start:stop] = new


def sparse_to_numpy(x):
    return x.to_dense().numpy()


def clone(x):
    return x.clone()


def reduce_sum(x):
    return x.sum()


def softmax(x, dim):
    return th.softmax(x, dim)


def spmm(x, y):
    return th.spmm(x, y)


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


def mul(a, b):
    return a * b


def div(a, b):
    return a / b


def sum(x, dim, keepdims=False):
    return x.sum(dim, keepdims=keepdims)


def max(x, dim):
    return x.max(dim)[0]


def min(x, dim):
    return x.min(dim)[0]


def prod(x, dim):
    return x.prod(dim)


def matmul(a, b):
    return a @ b


def dot(a, b):
    return sum(mul(a, b), dim=-1)


def abs(a):
    return a.abs()
@@ -6,7 +6,7 @@ from scipy.sparse import coo_matrix
def cuda():
    return "/gpu:0"


def is_cuda_available():
@@ -18,8 +18,12 @@ def array_equal(a, b):
def allclose(a, b, rtol=1e-4, atol=1e-4):
    return np.allclose(
        tf.convert_to_tensor(a).numpy(),
        tf.convert_to_tensor(b).numpy(),
        rtol=rtol,
        atol=atol,
    )


def randn(shape):
@@ -97,5 +101,6 @@ def matmul(a, b):
def dot(a, b):
    return sum(mul(a, b), dim=-1)


def abs(a):
    return tf.abs(a)
import itertools
import unittest
from collections import Counter
from itertools import product

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError

rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
fill_value = {"sum": 0, "max": float("-inf")}
feat_size = 2


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
    # test heterograph from the docstring, plus a user -- wishes -- game relation
    # 3 users, 2 games, 2 developers
@@ -28,12 +33,16 @@ def create_test_heterograph(idtype):
    # ('user', 'wishes', 'game'),
    # ('developer', 'develops', 'game')])
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g
@@ -45,49 +54,53 @@ def test_unary_copy_u(idtype):
        g = create_test_heterograph(idtype)

        x1 = F.randn((g.num_nodes("user"), feat_size))
        x2 = F.randn((g.num_nodes("developer"), feat_size))
        F.attach_grad(x1)
        F.attach_grad(x2)
        g.nodes["user"].data["h"] = x1
        g.nodes["developer"].data["h"] = x2

        #################################################################
        # apply_edges() is called on each relation type separately
        #################################################################
        with F.record_grad():
            [
                g.apply_edges(fn.copy_u("h", "m"), etype=rel)
                for rel in g.canonical_etypes
            ]
            r1 = g["plays"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            n_grad1 = F.grad(g.ndata["h"]["user"])
            # TODO (Israt): clear not working
            g.edata["m"].clear()

        #################################################################
        # apply_edges() is called on all relation types
        #################################################################
        g.apply_edges(fn.copy_u("h", "m"))
        r2 = g["plays"].edata["m"]
        F.backward(r2, F.ones(r2.shape))
        n_grad2 = F.grad(g.nodes["user"].data["h"])

        # correctness check
        def _print_error(a, b):
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(n_grad1, n_grad2):
            print("node grad")
            _print_error(n_grad1, n_grad2)
        assert F.allclose(n_grad1, n_grad2)

    _test(fn.copy_u)
@@ -99,51 +112,55 @@ def test_unary_copy_e(idtype):
        g = create_test_heterograph(idtype)
        feat_size = 2

        x1 = F.randn((4, feat_size))
        x2 = F.randn((4, feat_size))
        x3 = F.randn((3, feat_size))
        x4 = F.randn((3, feat_size))
        F.attach_grad(x1)
        F.attach_grad(x2)
        F.attach_grad(x3)
        F.attach_grad(x4)
        g["plays"].edata["eid"] = x1
        g["follows"].edata["eid"] = x2
        g["develops"].edata["eid"] = x3
        g["wishes"].edata["eid"] = x4

        #################################################################
        # apply_edges() is called on each relation type separately
        #################################################################
        with F.record_grad():
            [
                g.apply_edges(fn.copy_e("eid", "m"), etype=rel)
                for rel in g.canonical_etypes
            ]
            r1 = g["develops"].edata["m"]
            F.backward(r1, F.ones(r1.shape))
            e_grad1 = F.grad(g["develops"].edata["eid"])

        #################################################################
        # apply_edges() is called on all relation types
        #################################################################
        g.apply_edges(fn.copy_e("eid", "m"))
        r2 = g["develops"].edata["m"]
        F.backward(r2, F.ones(r2.shape))
        e_grad2 = F.grad(g["develops"].edata["eid"])

        # # correctness check
        def _print_error(a, b):
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if not F.allclose(e_grad1, e_grad2):
            print("edge grad")
            _print_error(e_grad1, e_grad2)
        assert F.allclose(e_grad1, e_grad2)

    _test(fn.copy_e)
@@ -154,14 +171,14 @@ def test_binary_op(idtype):
        g = create_test_heterograph(idtype)

        n1 = F.randn((g.num_nodes("user"), feat_size))
        n2 = F.randn((g.num_nodes("developer"), feat_size))
        n3 = F.randn((g.num_nodes("game"), feat_size))

        x1 = F.randn((g.num_edges("plays"), feat_size))
        x2 = F.randn((g.num_edges("follows"), feat_size))
        x3 = F.randn((g.num_edges("develops"), feat_size))
        x4 = F.randn((g.num_edges("wishes"), feat_size))

        builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
        builtin_msg = getattr(fn, builtin_msg_name)
@@ -173,25 +190,27 @@ def test_binary_op(idtype):
        F.attach_grad(n1)
        F.attach_grad(n2)
        F.attach_grad(n3)
        g.nodes["user"].data["h"] = n1
        g.nodes["developer"].data["h"] = n2
        g.nodes["game"].data["h"] = n3
        F.attach_grad(x1)
        F.attach_grad(x2)
        F.attach_grad(x3)
        F.attach_grad(x4)
        g["plays"].edata["h"] = x1
        g["follows"].edata["h"] = x2
        g["develops"].edata["h"] = x3
        g["wishes"].edata["h"] = x4

        with F.record_grad():
            [
                g.apply_edges(builtin_msg("h", "h", "m"), etype=rel)
                for rel in g.canonical_etypes
            ]
            r1 = g["plays"].edata["m"]
            loss = F.sum(r1.view(-1), 0)
            F.backward(loss)
            n_grad1 = F.grad(g.nodes["game"].data["h"])

        #################################################################
        # apply_edges() is called on all relation types
@@ -200,38 +219,40 @@ def test_binary_op(idtype):
        F.attach_grad(n1)
        F.attach_grad(n2)
        F.attach_grad(n3)
        g.nodes["user"].data["h"] = n1
        g.nodes["developer"].data["h"] = n2
        g.nodes["game"].data["h"] = n3
        F.attach_grad(x1)
        F.attach_grad(x2)
        F.attach_grad(x3)
        F.attach_grad(x4)
        g["plays"].edata["h"] = x1
        g["follows"].edata["h"] = x2
        g["develops"].edata["h"] = x3
        g["wishes"].edata["h"] = x4

        with F.record_grad():
            g.apply_edges(builtin_msg("h", "h", "m"))
            r2 = g["plays"].edata["m"]
            loss = F.sum(r2.view(-1), 0)
            F.backward(loss)
            n_grad2 = F.grad(g.nodes["game"].data["h"])

        # correctness check
        def _print_error(a, b):
            for i, (x, y) in enumerate(
                zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
            ):
                if not np.allclose(x, y):
                    print("@{} {} v.s. {}".format(i, x, y))

        if not F.allclose(r1, r2):
            _print_error(r1, r2)
        assert F.allclose(r1, r2)
        if n_grad1 is not None or n_grad2 is not None:
            if not F.allclose(n_grad1, n_grad2):
                print("node grad")
                _print_error(n_grad1, n_grad2)
            assert F.allclose(n_grad1, n_grad2)

    target = ["u", "v", "e"]
    for lhs, rhs in product(target, target):
@@ -242,6 +263,6 @@ def test_binary_op(idtype):
            _test(lhs, rhs, binary_op)


if __name__ == "__main__":
    test_unary_copy_u()
    test_unary_copy_e()
import os
import unittest

import backend as F


def test_set_default_backend():
    default_dir = os.path.join(os.path.expanduser("~"), ".dgl_unit_test")
    F.set_default_backend(default_dir, "pytorch")

    # make sure the config file was created
    assert os.path.exists(os.path.join(default_dir, "config.json"))
...
import unittest

import backend as F
from test_utils import parametrize_idtype

import dgl
from dgl.dataloading import (
    NeighborSampler,
    as_edge_prediction_sampler,
    negative_sampler,
)


def create_test_graph(idtype):
    # test heterograph from the docstring, plus a user -- wishes -- game relation
    # 3 users, 2 games, 2 developers
@@ -14,12 +20,16 @@ def create_test_graph(idtype):
    # ('user', 'wishes', 'game'),
    # ('developer', 'develops', 'game')])
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 2], [1, 0]),
            ("developer", "develops", "game"): ([0, 1], [0, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g
@@ -28,14 +38,15 @@ def create_test_graph(idtype):
@parametrize_idtype
def test_edge_prediction_sampler(idtype):
    g = create_test_graph(idtype)

    sampler = NeighborSampler([10, 10])
    sampler = as_edge_prediction_sampler(
        sampler, negative_sampler=negative_sampler.Uniform(1)
    )

    seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())

    # just a smoke test to make sure we don't fail internal assertions
    result = sampler.sample(g, {"follows": seeds})


if __name__ == "__main__":
    test_edge_prediction_sampler()
import itertools
import math
import unittest
from collections import Counter

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError
from dgl.ops import edge_softmax

rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
fill_value = {"sum": 0, "max": float("-inf")}
feat_size = 2


def create_test_heterograph(idtype):
    # test heterograph from the docstring, plus a user -- wishes -- game relation
    # 3 users, 2 games, 2 developers
@@ -27,37 +31,57 @@ def create_test_heterograph(idtype):
    # ('user', 'wishes', 'game'),
    # ('developer', 'develops', 'game')])
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1, 2, 1, 1], [0, 0, 1, 1, 2]),
            ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
            ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
            ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    assert g.idtype == idtype
    assert g.device == F.ctx()
    return g


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def test_edge_softmax_unidirectional():
    g = dgl.heterograph(
        {
            ("A", "AB", "B"): (
                [1, 2, 3, 1, 2, 3, 1, 2, 3],
                [0, 0, 0, 1, 1, 1, 2, 2, 2],
            ),
            ("B", "BB", "B"): (
                [0, 1, 2, 0, 1, 2, 0, 1, 2],
                [0, 0, 0, 1, 1, 1, 2, 2, 2],
            ),
        }
    )
    g = g.to(F.ctx())
    g.edges["AB"].data["x"] = F.ones(9) * 2
    g.edges["BB"].data["x"] = F.ones(9)
    result = dgl.ops.edge_softmax(
        g, {"AB": g.edges["AB"].data["x"], "BB": g.edges["BB"].data["x"]}
    )

    ab = result["A", "AB", "B"]
    bb = result["B", "BB", "B"]
    e2 = F.zeros_like(ab) + math.exp(2) / ((math.exp(2) + math.exp(1)) * 3)
    e1 = F.zeros_like(bb) + math.exp(1) / ((math.exp(2) + math.exp(1)) * 3)
    assert F.allclose(ab, e2)
    assert F.allclose(bb, e1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@pytest.mark.parametrize("g", get_cases(["clique"]))
@pytest.mark.parametrize("norm_by", ["src", "dst"])
# @pytest.mark.parametrize('shp', edge_softmax_shapes)
@parametrize_idtype
def test_edge_softmax(g, norm_by, idtype):
@@ -65,20 +89,20 @@ def test_edge_softmax(g, norm_by, idtype):
    g = create_test_heterograph(idtype)

    x1 = F.randn((g.num_edges("plays"), feat_size))
    x2 = F.randn((g.num_edges("follows"), feat_size))
    x3 = F.randn((g.num_edges("develops"), feat_size))
    x4 = F.randn((g.num_edges("wishes"), feat_size))

    F.attach_grad(F.clone(x1))
    F.attach_grad(F.clone(x2))
    F.attach_grad(F.clone(x3))
    F.attach_grad(F.clone(x4))

    g["plays"].edata["eid"] = x1
    g["follows"].edata["eid"] = x2
    g["develops"].edata["eid"] = x3
    g["wishes"].edata["eid"] = x4

    #################################################################
    # edge_softmax() on homogeneous graph
@@ -89,12 +113,12 @@ def test_edge_softmax(g, norm_by, idtype):
    hm_x = F.cat((x3, x2, x1, x4), 0)
    hm_e = F.attach_grad(F.clone(hm_x))
    score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)
    hm_g.edata["score"] = score_hm
    ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)
    r1 = ht_g.edata["score"][("user", "plays", "game")]
    r2 = ht_g.edata["score"][("user", "follows", "user")]
    r3 = ht_g.edata["score"][("developer", "develops", "game")]
    r4 = ht_g.edata["score"][("user", "wishes", "game")]
    F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))
    grad_edata_hm = F.grad(hm_e)
@@ -106,18 +130,22 @@ def test_edge_softmax(g, norm_by, idtype):
    e2 = F.attach_grad(F.clone(x2))
    e3 = F.attach_grad(F.clone(x3))
    e4 = F.attach_grad(F.clone(x4))
    e = {
        ("user", "follows", "user"): e2,
        ("user", "plays", "game"): e1,
        ("user", "wishes", "game"): e4,
        ("developer", "develops", "game"): e3,
    }
    with F.record_grad():
        score = edge_softmax(g, e, norm_by=norm_by)
        r5 = score[("user", "plays", "game")]
        r6 = score[("user", "follows", "user")]
        r7 = score[("developer", "develops", "game")]
        r8 = score[("user", "wishes", "game")]
        F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))
        grad_edata_ht = F.cat(
            (F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)), 0
        )
    # correctness check
    assert F.allclose(r1, r5)
    assert F.allclose(r2, r6)
@@ -125,5 +153,6 @@ def test_edge_softmax(g, norm_by, idtype):
    assert F.allclose(r4, r8)
    assert F.allclose(grad_edata_hm, grad_edata_ht)


if __name__ == "__main__":
    test_edge_softmax_unidirectional()
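
A quick standalone check of the constants asserted in test_edge_softmax_unidirectional above: every destination node of type B receives three "AB" edges with logit 2 and three "BB" edges with logit 1, so the per-edge softmax weights reduce to the two closed-form values used in the test (sketch only, no DGL required):

import math

p_ab = math.exp(2) / (3 * (math.exp(2) + math.exp(1)))  # weight of each "AB" edge
p_bb = math.exp(1) / (3 * (math.exp(2) + math.exp(1)))  # weight of each "BB" edge
# The six incoming edge weights at each destination node sum to one.
assert abs(3 * p_ab + 3 * p_bb - 1.0) < 1e-12
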