Unverified commit 98325b10, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4691)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent c24e285a
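For context: Black mostly normalizes string quotes to double quotes and re-wraps statements that exceed its line length. A representative before/after pair (illustrative only; `some_call` is a hypothetical name, not from this diff):

# before
ret = some_call('a', 'b',
                'c')

# after
ret = some_call("a", "b", "c")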
"""Views of DGLGraph."""
from __future__ import absolute_import
from collections import defaultdict, namedtuple
from collections.abc import MutableMapping

from . import backend as F
from .base import ALL, DGLError
from .frame import LazyFeature

NodeSpace = namedtuple("NodeSpace", ["data"])
EdgeSpace = namedtuple("EdgeSpace", ["data"])
class HeteroNodeView(object):
"""A NodeView class to act as G.nodes for a DGLHeteroGraph."""
__slots__ = ["_graph", "_typeid_getter"]
def __init__(self, graph, typeid_getter):
self._graph = graph
@@ -23,8 +24,9 @@ class HeteroNodeView(object):
def __getitem__(self, key):
if isinstance(key, slice):
# slice
if not (
    key.start is None and key.stop is None and key.step is None
):
raise DGLError('Currently only full slice ":" is supported')
nodes = ALL
ntype = None
@@ -38,20 +40,25 @@ class HeteroNodeView(object):
ntype = None
ntid = self._typeid_getter(ntype)
return NodeSpace(
    data=HeteroNodeDataView(self._graph, ntype, ntid, nodes)
)
def __call__(self, ntype=None):
"""Return the nodes."""
ntid = self._typeid_getter(ntype)
ret = F.arange(
    0,
    self._graph._graph.number_of_nodes(ntid),
    dtype=self._graph.idtype,
    ctx=self._graph.device,
)
return ret
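# Usage sketch (not part of the diff; the graph contents and feature name
# are illustrative):
#
#   g = dgl.heterograph({("user", "follows", "user"): ([0, 1], [1, 2])})
#   g.nodes["user"].data["h"] = torch.randn(3, 4)  # HeteroNodeDataView
#   g.nodes("user")                                # -> tensor([0, 1, 2])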
class HeteroNodeDataView(MutableMapping):
"""The data view class when G.ndata[ntype] is called."""
__slots__ = ["_graph", "_ntype", "_ntid", "_nodes"]
def __init__(self, graph, ntype, ntid, nodes):
self._graph = graph
@@ -63,9 +70,9 @@ class HeteroNodeDataView(MutableMapping):
if isinstance(self._ntype, list):
ret = {}
for (i, ntype) in enumerate(self._ntype):
value = self._graph._get_n_repr(self._ntid[i], self._nodes).get(
    key, None
)
if value is not None:
ret[ntype] = value
return ret
@@ -76,17 +83,19 @@ class HeteroNodeDataView(MutableMapping):
if isinstance(val, LazyFeature):
self._graph._node_frames[self._ntid][key] = val
elif isinstance(self._ntype, list):
assert isinstance(val, dict), (
    "Current HeteroNodeDataView has multiple node types, "
    "please pass the node type and the corresponding data through a dict."
)
for (ntype, data) in val.items():
ntid = self._graph.get_ntype_id(ntype)
self._graph._set_n_repr(ntid, self._nodes, {key: data})
else:
assert isinstance(val, dict) is False, (
    "The HeteroNodeDataView has only one node type. "
    "please pass a tensor directly"
)
self._graph._set_n_repr(self._ntid, self._nodes, {key: val})
def __delitem__(self, key):
@@ -108,8 +117,10 @@ class HeteroNodeDataView(MutableMapping):
else:
ret = self._graph._get_n_repr(self._ntid, self._nodes)
if as_dict:
ret = {
    key: ret[key]
    for key in self._graph._node_frames[self._ntid]
}
return ret
def __len__(self):
@@ -130,7 +141,8 @@ class HeteroNodeDataView(MutableMapping):
class HeteroEdgeView(object):
"""A EdgeView class to act as G.edges for a DGLHeteroGraph."""
__slots__ = ["_graph"]
def __init__(self, graph):
self._graph = graph
@@ -138,8 +150,9 @@ class HeteroEdgeView(object):
def __getitem__(self, key):
if isinstance(key, slice):
# slice
if not (
    key.start is None and key.stop is None and key.step is None
):
raise DGLError('Currently only full slice ":" is supported')
edges = ALL
etype = None
@@ -168,23 +181,26 @@ class HeteroEdgeView(object):
class HeteroEdgeDataView(MutableMapping):
"""The data view class when G.edata[etype] is called."""
__slots__ = ["_graph", "_etype", "_etid", "_edges"]
def __init__(self, graph, etype, edges):
self._graph = graph
self._etype = etype
self._etid = (
    [self._graph.get_etype_id(t) for t in etype]
    if isinstance(etype, list)
    else self._graph.get_etype_id(etype)
)
self._edges = edges
def __getitem__(self, key):
if isinstance(self._etype, list):
ret = {}
for (i, etype) in enumerate(self._etype):
value = self._graph._get_e_repr(self._etid[i], self._edges).get(
    key, None
)
if value is not None:
ret[etype] = value
return ret
@@ -195,17 +211,19 @@ class HeteroEdgeDataView(MutableMapping):
if isinstance(val, LazyFeature):
self._graph._edge_frames[self._etid][key] = val
elif isinstance(self._etype, list):
assert isinstance(val, dict), (
    "Current HeteroEdgeDataView has multiple edge types, "
    "please pass the edge type and the corresponding data through a dict."
)
for (etype, data) in val.items():
etid = self._graph.get_etype_id(etype)
self._graph._set_e_repr(etid, self._edges, {key: data})
else:
assert isinstance(val, dict) is False, (
    "The HeteroEdgeDataView has only one edge type. "
    "please pass a tensor directly"
)
self._graph._set_e_repr(self._etid, self._edges, {key: val})
def __delitem__(self, key):
@@ -227,8 +245,10 @@ class HeteroEdgeDataView(MutableMapping):
else:
ret = self._graph._get_e_repr(self._etid, self._edges)
if as_dict:
ret = {
    key: ret[key]
    for key in self._graph._edge_frames[self._etid]
}
return ret
def __len__(self):
...
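A usage sketch for the edge views above (not part of the diff; the graph and feature name are illustrative): with multiple edge types, assigning through `g.edata` requires a dict keyed by canonical edge type, which is what the asserts in `HeteroEdgeDataView.__setitem__` enforce.

import dgl
import torch

g = dgl.heterograph({
    ("user", "plays", "game"): ([0, 1], [0, 1]),
    ("user", "follows", "user"): ([0], [1]),
})
g.edata["w"] = {  # a bare tensor here would trip the multi-type assert
    ("user", "plays", "game"): torch.ones(2, 3),
    ("user", "follows", "user"): torch.zeros(1, 3),
}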
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import glob
import os
import platform
import shutil
import sys
import sysconfig
from setuptools import find_packages
from setuptools.dist import Distribution
# need to use distutils.core for correct placement of cython dll
if "--inplace" in sys.argv:
from distutils.core import setup
from distutils.extension import Extension
else:
@@ -31,34 +31,35 @@ def get_lib_path():
"""Get library path, name and version"""
# We cannot import `libinfo.py` in setup.py directly since __init__.py
# will be invoked, which introduces dependencies.
libinfo_py = os.path.join(CURRENT_DIR, "./dgl/_ffi/libinfo.py")
libinfo = {"__file__": libinfo_py}
exec(
    compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"),
    libinfo,
    libinfo,
)
version = libinfo["__version__"]
lib_path = libinfo["find_lib_path"]()
libs = [lib_path[0]]
return libs, version
def get_ta_lib_pattern():
if sys.platform.startswith("linux"):
    ta_lib_pattern = "libtensoradapter_*.so"
elif sys.platform.startswith("darwin"):
    ta_lib_pattern = "libtensoradapter_*.dylib"
elif sys.platform.startswith("win"):
    ta_lib_pattern = "tensoradapter_*.dll"
else:
    raise NotImplementedError("Unsupported system: %s" % sys.platform)
return ta_lib_pattern
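# Aside (not in the diff): on a Linux PyTorch build this pattern would match
# tensor-adapter libraries named roughly like
# "libtensoradapter_pytorch_1.12.0.so"; the exact file name is an
# illustrative assumption.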
LIBS, VERSION = get_lib_path()
BACKENDS = ["pytorch"]
TA_LIB_PATTERN = get_ta_lib_pattern()
@@ -78,11 +79,9 @@ def cleanup():
for backend in BACKENDS:
for ta_path in glob.glob(
    os.path.join(
        CURRENT_DIR, "dgl", "tensoradapter", backend, TA_LIB_PATTERN
    )
):
try:
os.remove(ta_path)
except BaseException:
@@ -91,17 +90,21 @@ def cleanup():
def config_cython():
"""Try to configure cython and return cython configuration"""
if sys.platform.startswith("win"):
    print(
        "WARNING: Cython is not supported on Windows, will compile without cython module"
    )
    return []
sys_cflags = sysconfig.get_config_var("CFLAGS")
if "i386" in sys_cflags and "x86_64" in sys_cflags:
    print(
        "WARNING: Cython library may not be compiled correctly with both i386 and x64"
    )
return []
try:
from Cython.Build import cythonize
# from setuptools.extension import Extension
if sys.version_info >= (3, 0):
subdir = "_cy3"
@@ -109,32 +112,38 @@ def config_cython():
subdir = "_cy2"
ret = []
path = "dgl/_ffi/_cython"
library_dirs = ["dgl", "../build/Release", "../build"]
libraries = ["dgl"]
for fn in os.listdir(path):
if not fn.endswith(".pyx"):
continue
ret.append(
    Extension(
        "dgl._ffi.%s.%s" % (subdir, fn[:-4]),
        ["dgl/_ffi/_cython/%s" % fn],
        include_dirs=[
            "../include/",
            "../third_party/dmlc-core/include",
            "../third_party/dlpack/include",
        ],
        library_dirs=library_dirs,
        libraries=libraries,
        # Crashes without this flag with GCC 5.3.1
        extra_compile_args=["-std=c++11"],
        language="c++",
    )
)
return cythonize(ret, force=True)
except ImportError:
print("WARNING: Cython is not installed, will compile without cython module")
print(
"WARNING: Cython is not installed, will compile without cython module"
)
return []
include_libs = False
wheel_include_libs = False
if "bdist_wheel" in sys.argv or os.getenv('CONDA_BUILD'):
if "bdist_wheel" in sys.argv or os.getenv("CONDA_BUILD"):
wheel_include_libs = True
elif "clean" in sys.argv:
cleanup()
@@ -147,78 +156,76 @@ setup_kwargs = {}
if wheel_include_libs:
with open("MANIFEST.in", "w") as fo:
for path in LIBS:
shutil.copy(path, os.path.join(CURRENT_DIR, "dgl"))
dir_, libname = os.path.split(path)
fo.write("include dgl/%s\n" % libname)
for backend in BACKENDS:
for ta_path in glob.glob(
    os.path.join(dir_, "tensoradapter", backend, TA_LIB_PATTERN)
):
ta_name = os.path.basename(ta_path)
os.makedirs(
    os.path.join(CURRENT_DIR, "dgl", "tensoradapter", backend),
    exist_ok=True,
)
shutil.copy(
    os.path.join(dir_, "tensoradapter", backend, ta_name),
    os.path.join(CURRENT_DIR, "dgl", "tensoradapter", backend),
)
fo.write(
"include dgl/tensoradapter/%s/%s\n" %
(backend, ta_name))
"include dgl/tensoradapter/%s/%s\n" % (backend, ta_name)
)
setup_kwargs = {"include_package_data": True}
# For source tree setup
# Conda build also includes the binary library
if include_libs:
rpath = [os.path.relpath(path, CURRENT_DIR) for path in LIBS]
data_files = [("dgl", rpath)]
for path in LIBS:
for backend in BACKENDS:
data_files.append(
    (
        "dgl/tensoradapter/%s" % backend,
        glob.glob(
            os.path.join(
                os.path.dirname(os.path.relpath(path, CURRENT_DIR)),
                "tensoradapter",
                backend,
                TA_LIB_PATTERN,
            )
        ),
    )
)
setup_kwargs = {"include_package_data": True, "data_files": data_files}
setup(
name="dgl" + os.getenv("DGL_PACKAGE_SUFFIX", ""),
version=VERSION,
description="Deep Graph Library",
zip_safe=False,
maintainer="DGL Team",
maintainer_email="wmjlyjemaine@gmail.com",
packages=find_packages(),
install_requires=[
    "numpy>=1.14.0",
    "scipy>=1.1.0",
    "networkx>=2.1",
    "requests>=2.19.0",
    "tqdm",
    "psutil>=5.8.0",
],
url="https://github.com/dmlc/dgl",
distclass=BinaryDistribution,
ext_modules=config_cython(),
classifiers=[
    "Development Status :: 3 - Alpha",
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
],
license="APACHE",
**setup_kwargs
)
...
@@ -8,10 +8,11 @@ List of affected files:
"""
import os
import re
# We use the version of the incoming release for code
# that is under development
__version__ = "0.10" + os.getenv('DGL_PRERELEASE', '')
__version__ = "0.10" + os.getenv("DGL_PRERELEASE", "")
print(__version__)
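# Illustration (not in the diff); the env value below is hypothetical:
#   DGL_PRERELEASE unset      -> __version__ == "0.10"
#   DGL_PRERELEASE="a220816"  -> __version__ == "0.10a220816"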
# Implementations
@@ -47,22 +48,24 @@ def main():
curr_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_dir, ".."))
# python path
update(os.path.join(proj_root, "python", "dgl", "_ffi", "libinfo.py"),
r"(?<=__version__ = \")[.0-9a-z]+", __version__)
update(
os.path.join(proj_root, "python", "dgl", "_ffi", "libinfo.py"),
r"(?<=__version__ = \")[.0-9a-z]+",
__version__,
)
# C++ header
update(
    os.path.join(proj_root, "include", "dgl", "runtime", "c_runtime_api.h"),
    '(?<=DGL_VERSION ")[.0-9a-z]+',
    __version__,
)
# conda
for path in ["dgl"]:
update(os.path.join(proj_root, "conda", path, "meta.yaml"),
"(?<=version: \")[.0-9a-z]+", __version__)
update(
os.path.join(proj_root, "conda", path, "meta.yaml"),
'(?<=version: ")[.0-9a-z]+',
__version__,
)
if __name__ == "__main__":
...
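The `update()` helper called above is elided from this diff. A minimal sketch of such a regex-replace helper, assuming it rewrites each file in place (hypothetical reconstruction, not the committed code):

import re

def update(path, pattern, version):
    # Replace the version string matched by `pattern` with `version`.
    with open(path) as f:
        text = f.read()
    with open(path, "w") as f:
        f.write(re.sub(pattern, version, text))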
import os

import torch
cmake_prefix_path = getattr(
torch.utils,
"cmake_prefix_path",
    os.path.join(os.path.dirname(torch.__file__), "share", "cmake"),
)
version = torch.__version__.split("+")[0]
print(";".join([cmake_prefix_path, version]))
import importlib
import os
import sys

import numpy as np

from dgl.backend import *
from dgl.nn import *

from . import backend_unittest

mod = importlib.import_module(".%s" % backend_name, __name__)
thismod = sys.modules[__name__]
for api in backend_unittest.__dict__.keys():
if api.startswith("__"):
continue
elif callable(mod.__dict__[api]):
# Tensor APIs used in unit tests MUST be supported across all backends
@@ -26,39 +29,51 @@ _arange = arange
_full = full
_full_1d = full_1d
_softmax = softmax
_default_context_str = os.getenv("DGLTESTDEV", "cpu")
_context_dict = {
    "cpu": cpu(),
    "gpu": cuda(),
}
_default_context = _context_dict[_default_context_str]
def ctx():
return _default_context
def gpu_ctx():
return _default_context_str == "gpu"
def zeros(shape, dtype=float32, ctx=_default_context):
return _zeros(shape, dtype, ctx)
def ones(shape, dtype=float32, ctx=_default_context):
return _ones(shape, dtype, ctx)
def randn(shape):
return copy_to(_randn(shape), _default_context)
def tensor(data, dtype=None):
return copy_to(_tensor(data, dtype), _default_context)
def arange(start, stop, dtype=int64, ctx=None):
return _arange(
    start, stop, dtype, ctx if ctx is not None else _default_context
)
def full(shape, fill_value, dtype, ctx=_default_context):
return _full(shape, fill_value, dtype, ctx)
def full_1d(length, fill_value, dtype, ctx=_default_context):
return _full_1d(length, fill_value, dtype, ctx)
def softmax(x, dim):
return _softmax(x, dim)
@@ -5,102 +5,127 @@ unit testing, other than the ones used in the framework itself.
###############################################################################
# Tensor, data type and context interfaces
def cuda():
"""Context object for CUDA."""
pass
def is_cuda_available():
"""Check whether CUDA is available."""
pass
###############################################################################
# Tensor functions on feature data
# --------------------------------
# These functions are performance critical, so it's better to have efficient
# implementation in each framework.
def array_equal(a, b):
"""Check whether the two tensors are *exactly* equal."""
pass
def allclose(a, b, rtol=1e-4, atol=1e-4):
"""Check whether the two tensors are numerically close to each other."""
pass
def randn(shape):
"""Generate a tensor with elements from standard normal distribution."""
pass
def full(shape, fill_value, dtype, ctx):
pass
def narrow_row_set(x, start, stop, new):
"""Set a slice of the given tensor to a new value."""
pass
def sparse_to_numpy(x):
"""Convert a sparse tensor to a numpy array."""
pass
def clone(x):
pass
def reduce_sum(x):
"""Sums all the elements into a single scalar."""
pass
def softmax(x, dim):
"""Softmax Operation on Tensors"""
pass
def spmm(x, y):
"""Sparse dense matrix multiply"""
pass
def add(a, b):
"""Compute a + b"""
pass
def sub(a, b):
"""Compute a - b"""
pass
def mul(a, b):
"""Compute a * b"""
pass
def div(a, b):
"""Compute a / b"""
pass
def sum(x, dim, keepdims=False):
"""Computes the sum of array elements over given axes"""
pass
def max(x, dim):
"""Computes the max of array elements over given axes"""
pass
def min(x, dim):
"""Computes the min of array elements over given axes"""
pass
def prod(x, dim):
"""Computes the prod of array elements over given axes"""
pass
def matmul(a, b):
"""Compute Matrix Multiplication between a and b"""
pass
def dot(a, b):
"""Compute Dot between a and b"""
pass
def abs(a):
"""Compute the absolute value of a"""
pass
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
...
from __future__ import absolute_import
import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
def cuda():
return mx.gpu()
def is_cuda_available():
# TODO: Does MXNet have a convenient function to test GPU availability/compilation?
try:
@@ -15,65 +17,86 @@ def is_cuda_available():
except mx.MXNetError:
return False
def array_equal(a, b):
return nd.equal(a, b).asnumpy().all()
def allclose(a, b, rtol=1e-4, atol=1e-4):
return np.allclose(a.asnumpy(), b.asnumpy(), rtol=rtol, atol=atol)
def randn(shape):
return nd.random.randn(*shape)
def full(shape, fill_value, dtype, ctx):
return nd.full(shape, fill_value, dtype=dtype, ctx=ctx)
def narrow_row_set(x, start, stop, new):
x[start:stop] = new
def sparse_to_numpy(x):
return x.asscipy().todense().A
def clone(x):
return x.copy()
def reduce_sum(x):
return x.sum()
def softmax(x, dim):
return nd.softmax(x, axis=dim)
def spmm(x, y):
return nd.dot(x, y)
def add(a, b):
return a + b
def sub(a, b):
return a - b
def mul(a, b):
return a * b
def div(a, b):
return a / b
def sum(x, dim, keepdims=False):
return x.sum(dim, keepdims=keepdims)
def max(x, dim):
return x.max(dim)
def min(x, dim):
return x.min(dim)
def prod(x, dim):
return x.prod(dim)
def matmul(a, b):
return nd.dot(a, b)
def dot(a, b):
return nd.sum(mul(a, b), axis=-1)
def abs(a):
return nd.abs(a)
@@ -2,72 +2,94 @@ from __future__ import absolute_import
import torch as th
def cuda():
return th.device("cuda:0")
def is_cuda_available():
return th.cuda.is_available()
def array_equal(a, b):
return th.equal(a.cpu(), b.cpu())
def allclose(a, b, rtol=1e-4, atol=1e-4):
return th.allclose(a.float().cpu(), b.float().cpu(), rtol=rtol, atol=atol)
def randn(shape):
return th.randn(*shape)
def full(shape, fill_value, dtype, ctx):
return th.full(shape, fill_value, dtype=dtype, device=ctx)
def narrow_row_set(x, start, stop, new):
x[start:stop] = new
def sparse_to_numpy(x):
return x.to_dense().numpy()
def clone(x):
return x.clone()
def reduce_sum(x):
return x.sum()
def softmax(x, dim):
return th.softmax(x, dim)
def spmm(x, y):
return th.spmm(x, y)
def add(a, b):
return a + b
def sub(a, b):
return a - b
def mul(a, b):
return a * b
def div(a, b):
return a / b
def sum(x, dim, keepdims=False):
return x.sum(dim, keepdims=keepdims)
def max(x, dim):
return x.max(dim)[0]
def min(x, dim):
return x.min(dim)[0]
def prod(x, dim):
return x.prod(dim)
def matmul(a, b):
return a @ b
def dot(a, b):
return sum(mul(a, b), dim=-1)
def abs(a):
return a.abs()
@@ -6,7 +6,7 @@ from scipy.sparse import coo_matrix
def cuda():
return "/gpu:0"
def is_cuda_available():
@@ -18,8 +18,12 @@ def array_equal(a, b):
def allclose(a, b, rtol=1e-4, atol=1e-4):
return np.allclose(
    tf.convert_to_tensor(a).numpy(),
    tf.convert_to_tensor(b).numpy(),
    rtol=rtol,
    atol=atol,
)
def randn(shape):
@@ -97,5 +101,6 @@ def matmul(a, b):
def dot(a, b):
return sum(mul(a, b), dim=-1)
def abs(a):
return tf.abs(a)
import itertools
import unittest
from collections import Counter
from itertools import product

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError

rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
fill_value = {"sum": 0, "max": float("-inf")}
feat_size = 2
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
@@ -28,12 +33,16 @@ def create_test_heterograph(idtype):
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])
g = dgl.heterograph(
    {
        ("user", "follows", "user"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
        ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
    },
    idtype=idtype,
    device=F.ctx(),
)
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@@ -45,49 +54,53 @@ def test_unary_copy_u(idtype):
g = create_test_heterograph(idtype)
x1 = F.randn((g.num_nodes("user"), feat_size))
x2 = F.randn((g.num_nodes("developer"), feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
g.nodes["user"].data["h"] = x1
g.nodes["developer"].data["h"] = x2
#################################################################
# apply_edges() is called on each relation type separately
#################################################################
with F.record_grad():
[
    g.apply_edges(fn.copy_u("h", "m"), etype=rel)
    for rel in g.canonical_etypes
]
r1 = g["plays"].edata["m"]
F.backward(r1, F.ones(r1.shape))
n_grad1 = F.grad(g.ndata["h"]["user"])
# TODO (Israt): clear not working
g.edata["m"].clear()
#################################################################
# apply_edges() is called on all relation types
#################################################################
g.apply_edges(fn.copy_u("h", "m"))
r2 = g["plays"].edata["m"]
F.backward(r2, F.ones(r2.shape))
n_grad2 = F.grad(g.nodes["user"].data["h"])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(
    zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
):
    if not np.allclose(x, y):
        print("@{} {} v.s. {}".format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(n_grad1, n_grad2):
print("node grad")
_print_error(n_grad1, n_grad2)
assert F.allclose(n_grad1, n_grad2)
_test(fn.copy_u)
@@ -99,51 +112,55 @@ def test_unary_copy_e(idtype):
g = create_test_heterograph(idtype)
feat_size = 2
x1 = F.randn((4, feat_size))
x2 = F.randn((4, feat_size))
x3 = F.randn((3, feat_size))
x4 = F.randn((3, feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g["plays"].edata["eid"] = x1
g["follows"].edata["eid"] = x2
g["develops"].edata["eid"] = x3
g["wishes"].edata["eid"] = x4
#################################################################
# apply_edges() is called on each relation type separately
#################################################################
with F.record_grad():
[
    g.apply_edges(fn.copy_e("eid", "m"), etype=rel)
    for rel in g.canonical_etypes
]
r1 = g["develops"].edata["m"]
F.backward(r1, F.ones(r1.shape))
e_grad1 = F.grad(g["develops"].edata["eid"])
#################################################################
# apply_edges() is called on all relation types
#################################################################
g.apply_edges(fn.copy_e("eid", "m"))
r2 = g["develops"].edata["m"]
F.backward(r2, F.ones(r2.shape))
e_grad2 = F.grad(g["develops"].edata["eid"])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(
    zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
):
    if not np.allclose(x, y):
        print("@{} {} v.s. {}".format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(e_grad1, e_grad2):
print("edge grad")
_print_error(e_grad1, e_grad2)
assert F.allclose(e_grad1, e_grad2)
_test(fn.copy_e)
@@ -154,14 +171,14 @@ def test_binary_op(idtype):
g = create_test_heterograph(idtype)
n1 = F.randn((g.num_nodes("user"), feat_size))
n2 = F.randn((g.num_nodes("developer"), feat_size))
n3 = F.randn((g.num_nodes("game"), feat_size))

x1 = F.randn((g.num_edges("plays"), feat_size))
x2 = F.randn((g.num_edges("follows"), feat_size))
x3 = F.randn((g.num_edges("develops"), feat_size))
x4 = F.randn((g.num_edges("wishes"), feat_size))
builtin_msg_name = "{}_{}_{}".format(lhs, binary_op, rhs)
builtin_msg = getattr(fn, builtin_msg_name)
@@ -173,25 +190,27 @@ def test_binary_op(idtype):
F.attach_grad(n1)
F.attach_grad(n2)
F.attach_grad(n3)
g.nodes["user"].data["h"] = n1
g.nodes["developer"].data["h"] = n2
g.nodes["game"].data["h"] = n3
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g["plays"].edata["h"] = x1
g["follows"].edata["h"] = x2
g["develops"].edata["h"] = x3
g["wishes"].edata["h"] = x4
with F.record_grad():
[
    g.apply_edges(builtin_msg("h", "h", "m"), etype=rel)
    for rel in g.canonical_etypes
]
r1 = g["plays"].edata["m"]
loss = F.sum(r1.view(-1), 0)
F.backward(loss)
n_grad1 = F.grad(g.nodes["game"].data["h"])
#################################################################
# apply_edges() is called on all relation types
@@ -200,38 +219,40 @@ def test_binary_op(idtype):
F.attach_grad(n1)
F.attach_grad(n2)
F.attach_grad(n3)
g.nodes["user"].data["h"] = n1
g.nodes["developer"].data["h"] = n2
g.nodes["game"].data["h"] = n3
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g["plays"].edata["h"] = x1
g["follows"].edata["h"] = x2
g["develops"].edata["h"] = x3
g["wishes"].edata["h"] = x4
with F.record_grad():
g.apply_edges(builtin_msg("h", "h", "m"))
r2 = g["plays"].edata["m"]
loss = F.sum(r2.view(-1), 0)
F.backward(loss)
n_grad2 = F.grad(g.nodes["game"].data["h"])
# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(
    zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())
):
    if not np.allclose(x, y):
        print("@{} {} v.s. {}".format(i, x, y))
if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if n_grad1 is not None or n_grad2 is not None:
if not F.allclose(n_grad1, n_grad2):
print("node grad")
_print_error(n_grad1, n_grad2)
assert F.allclose(n_grad1, n_grad2)
target = ["u", "v", "e"]
for lhs, rhs in product(target, target):
@@ -242,6 +263,6 @@ def test_binary_op(idtype):
_test(lhs, rhs, binary_op)
if __name__ == "__main__":
test_unary_copy_u()
test_unary_copy_e()
import os
import unittest

import backend as F
def test_set_default_backend():
default_dir = os.path.join(os.path.expanduser("~"), ".dgl_unit_test")
F.set_default_backend(default_dir, "pytorch")
# make sure the config file was created
assert os.path.exists(os.path.join(default_dir, "config.json"))
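# The config file presumably stores the chosen backend, e.g.
# {"default_backend": "pytorch"}; the exact schema is an assumption,
# the test only checks that the file exists.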
This diff is collapsed.
import unittest

import backend as F
from test_utils import parametrize_idtype

import dgl
from dgl.dataloading import (
    NeighborSampler,
    as_edge_prediction_sampler,
    negative_sampler,
)
def create_test_graph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
@@ -14,12 +20,16 @@ def create_test_graph(idtype):
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])
g = dgl.heterograph(
    {
        ("user", "follows", "user"): ([0, 1], [1, 2]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 2], [1, 0]),
        ("developer", "develops", "game"): ([0, 1], [0, 1]),
    },
    idtype=idtype,
    device=F.ctx(),
)
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@@ -28,14 +38,15 @@ def create_test_graph(idtype):
@parametrize_idtype
def test_edge_prediction_sampler(idtype):
g = create_test_graph(idtype)
sampler = NeighborSampler([10, 10])
sampler = as_edge_prediction_sampler(
    sampler, negative_sampler=negative_sampler.Uniform(1)
)
seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())
# just a smoke test to make sure we don't fail internal assertions
result = sampler.sample(g, {"follows": seeds})
if __name__ == "__main__":
test_edge_prediction_sampler()
import itertools
import math
import unittest
from collections import Counter

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy.sparse as ssp
import test_utils
from scipy.sparse import rand
from test_utils import get_cases, parametrize_idtype

import dgl
import dgl.function as fn
from dgl import DGLError
from dgl.ops import edge_softmax

rfuncs = {"sum": fn.sum, "max": fn.max, "min": fn.min, "mean": fn.mean}
fill_value = {"sum": 0, "max": float("-inf")}
feat_size = 2
def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
@@ -27,37 +31,57 @@ def create_test_heterograph(idtype):
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])
g = dgl.heterograph(
    {
        ("user", "follows", "user"): ([0, 1, 2, 1, 1], [0, 0, 1, 1, 2]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 1, 1], [0, 0, 1]),
        ("developer", "develops", "game"): ([0, 1, 0], [0, 1, 1]),
    },
    idtype=idtype,
    device=F.ctx(),
)
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
def test_edge_softmax_unidirectional():
g = dgl.heterograph(
    {
        ("A", "AB", "B"): (
            [1, 2, 3, 1, 2, 3, 1, 2, 3],
            [0, 0, 0, 1, 1, 1, 2, 2, 2],
        ),
        ("B", "BB", "B"): (
            [0, 1, 2, 0, 1, 2, 0, 1, 2],
            [0, 0, 0, 1, 1, 1, 2, 2, 2],
        ),
    }
)
g = g.to(F.ctx())
g.edges["AB"].data["x"] = F.ones(9) * 2
g.edges["BB"].data["x"] = F.ones(9)
result = dgl.ops.edge_softmax(
    g, {"AB": g.edges["AB"].data["x"], "BB": g.edges["BB"].data["x"]}
)
ab = result["A", "AB", "B"]
bb = result["B", "BB", "B"]
e2 = F.zeros_like(ab) + math.exp(2) / ((math.exp(2) + math.exp(1)) * 3)
e1 = F.zeros_like(bb) + math.exp(1) / ((math.exp(2) + math.exp(1)) * 3)
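# Each B node has three incoming AB edges (score 2) and three incoming BB
# edges (score 1), so softmax over its six in-edges yields
# exp(2) / (3 * (exp(2) + exp(1))) per AB edge and
# exp(1) / (3 * (exp(2) + exp(1))) per BB edge, matching e2 and e1 above.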
assert F.allclose(ab, e2)
assert F.allclose(bb, e1)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@pytest.mark.parametrize("g", get_cases(["clique"]))
@pytest.mark.parametrize("norm_by", ["src", "dst"])
# @pytest.mark.parametrize('shp', edge_softmax_shapes)
@parametrize_idtype
def test_edge_softmax(g, norm_by, idtype):
@@ -65,20 +89,20 @@ def test_edge_softmax(g, norm_by, idtype):
g = create_test_heterograph(idtype)
x1 = F.randn((g.num_edges("plays"), feat_size))
x2 = F.randn((g.num_edges("follows"), feat_size))
x3 = F.randn((g.num_edges("develops"), feat_size))
x4 = F.randn((g.num_edges("wishes"), feat_size))
F.attach_grad(F.clone(x1))
F.attach_grad(F.clone(x2))
F.attach_grad(F.clone(x3))
F.attach_grad(F.clone(x4))
g["plays"].edata["eid"] = x1
g["follows"].edata["eid"] = x2
g["develops"].edata["eid"] = x3
g["wishes"].edata["eid"] = x4
#################################################################
# edge_softmax() on homogeneous graph
@@ -89,12 +113,12 @@ def test_edge_softmax(g, norm_by, idtype):
hm_x = F.cat((x3, x2, x1, x4), 0)
hm_e = F.attach_grad(F.clone(hm_x))
score_hm = edge_softmax(hm_g, hm_e, norm_by=norm_by)
hm_g.edata["score"] = score_hm
ht_g = dgl.to_heterogeneous(hm_g, g.ntypes, g.etypes)
r1 = ht_g.edata["score"][("user", "plays", "game")]
r2 = ht_g.edata["score"][("user", "follows", "user")]
r3 = ht_g.edata["score"][("developer", "develops", "game")]
r4 = ht_g.edata["score"][("user", "wishes", "game")]
F.backward(F.reduce_sum(r1) + F.reduce_sum(r2))
grad_edata_hm = F.grad(hm_e)
@@ -106,18 +130,22 @@ def test_edge_softmax(g, norm_by, idtype):
e2 = F.attach_grad(F.clone(x2))
e3 = F.attach_grad(F.clone(x3))
e4 = F.attach_grad(F.clone(x4))
e = {
    ("user", "follows", "user"): e2,
    ("user", "plays", "game"): e1,
    ("user", "wishes", "game"): e4,
    ("developer", "develops", "game"): e3,
}
with F.record_grad():
score = edge_softmax(g, e, norm_by=norm_by)
r5 = score[("user", "plays", "game")]
r6 = score[("user", "follows", "user")]
r7 = score[("developer", "develops", "game")]
r8 = score[("user", "wishes", "game")]
F.backward(F.reduce_sum(r5) + F.reduce_sum(r6))
grad_edata_ht = F.cat(
    (F.grad(e3), F.grad(e2), F.grad(e1), F.grad(e4)), 0
)
# correctness check
assert F.allclose(r1, r5)
assert F.allclose(r2, r6)
@@ -125,5 +153,6 @@ def test_edge_softmax(g, norm_by, idtype):
assert F.allclose(r4, r8)
assert F.allclose(grad_edata_hm, grad_edata_ht)
if __name__ == "__main__":
test_edge_softmax_unidirectional()
This diff is collapsed.